File size: 6,074 Bytes
0873e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import argparse
import math  # math.prod is Python 3.8+
import os
import sys
from collections import Counter

from safetensors import safe_open

# --- Dtype to Bytes Mapping ---
# Bytes per element for the dtype strings stored in safetensors headers:
#   BOOL, U8, I8, F8_E5M2, F8_E4M3, I16, U16, F16, BF16, I32, U32, F32,
#   F64, I64, U64, plus (in newer safetensors releases) F8_E8M0 and C64.
# A few alternative float8 spellings are accepted as a convenience.
# Keys must be upper-case: get_bytes_per_element() upper-cases its input
# before looking it up here. Dtypes that are not listed (e.g. packed
# sub-byte formats) make that lookup return None, and the tensor's size is
# then reported as N/A.
DTYPE_TO_BYTES = {
    "BOOL": 1,
    # Float8 variants
    "F8_E5M2": 1,
    "F8E5M2": 1,        # Alternative spelling
    "F8_E4M3": 1,       # Name safetensors writes for torch.float8_e4m3fn
    "F8_E4M3FN": 1,
    "F8E4M3FN": 1,      # Alternative spelling
    "F8_E5M2FNUZ": 1,
    "F8E5M2FNUZ": 1,    # Alternative spelling
    "F8_E4M3FNUZ": 1,
    "F8E4M3FNUZ": 1,    # Alternative spelling
    "F8_E8M0": 1,       # Exponent-only float8 (newer safetensors releases)
    # Standard floats
    "F16": 2,
    "BF16": 2,
    "F32": 4,
    "F64": 8,
    # Complex
    "C64": 8,           # complex64: two float32 components
    # Integers
    "I8": 1,
    "I16": 2,
    "I32": 4,
    "I64": 8,
    # Unsigned Integers
    "U8": 1,
    "U16": 2,
    "U32": 4,
    "U64": 8,
}

def get_bytes_per_element(dtype_str):
    """Look up how many bytes one element of a safetensors dtype occupies.

    Args:
        dtype_str: dtype name as found in a safetensors header; it is matched
            case-insensitively against the keys of DTYPE_TO_BYTES.

    Returns:
        The element size in bytes, or None if the dtype is not in the table.
    """
    canonical_name = dtype_str.upper()
    return DTYPE_TO_BYTES.get(canonical_name)

def calculate_num_elements(shape):
    """Return the total number of elements described by a tensor shape.

    Args:
        shape: Sequence (tuple or list) of dimension sizes. An empty shape
            denotes a scalar and yields 1; any zero-sized dimension yields 0.

    Returns:
        The product of all dimension sizes.
    """
    if not shape:
        # Scalar tensor (shape is ()); a falsy shape is treated the same way.
        return 1
    # math.prod already yields 0 when any dimension is 0, so no pre-scan is needed.
    return math.prod(shape)

# Widths of the dtype, actual-size and FP32-size columns of the per-tensor
# table; the tensor-name column is sized per file from the longest key.
_DTYPE_COL_WIDTH = 17
_ACTUAL_COL_WIDTH = 16
_FP32_COL_WIDTH = 18


def _format_table_row(name, dtype, actual, fp32_equiv, name_width):
    """Return one ' | '-separated line of the per-tensor table (header or data)."""
    return (
        f"{name:<{name_width}} | "
        f"{dtype:<{_DTYPE_COL_WIDTH}} | "
        f"{actual:>{_ACTUAL_COL_WIDTH}} | "
        f"{fp32_equiv:>{_FP32_COL_WIDTH}}"
    )


def _table_separator(name_width):
    """Return the dashed rule printed under the table header."""
    # Each ' | ' in a row becomes '-|-' here, so the pipes line up by construction.
    return "-|-".join(
        "-" * width
        for width in (name_width, _DTYPE_COL_WIDTH, _ACTUAL_COL_WIDTH, _FP32_COL_WIDTH)
    )


def _print_summary(num_tensors, dtype_counts, total_actual_mb,
                   total_fp32_equiv_mb, num_unknown_dtype):
    """Print the dtype distribution, size totals and overall FP32 size reduction.

    Args:
        num_tensors: Number of tensors listed in the file.
        dtype_counts: Counter mapping dtype string -> number of tensors.
        total_actual_mb: Summed stored size (MiB) of tensors with a known dtype.
        total_fp32_equiv_mb: Summed FP32-equivalent size (MiB) of the same tensors.
        num_unknown_dtype: Tensors left out of both totals because their
            dtype is not in DTYPE_TO_BYTES.
    """
    print("\n--- Summary ---")
    print(f"Total tensors found: {num_tensors}")
    if dtype_counts:
        print("Precision distribution:")
        for dtype, count in dtype_counts.most_common():
            print(f"  - {dtype:<12}: {count} tensor(s)")
    else:
        print("No dtypes to summarize.")

    if num_unknown_dtype:
        # Make the omission explicit; otherwise the totals silently under-report.
        print(f"Tensors with unknown dtype (excluded from size totals): {num_unknown_dtype}")

    print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
    print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

    # Threshold guards against dividing by zero or a vanishingly small total.
    if total_fp32_equiv_mb > 0.00001:
        savings_percentage = (1 - (total_actual_mb / total_fp32_equiv_mb)) * 100
        print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
    else:
        print("Overall size reduction cannot be calculated (no FP32 equivalent data or zero size).")


def inspect_safetensors_precision_and_size(filepath):
    """
    Reads a .safetensors file, iterates through its tensors,
    and reports the precision (dtype), actual size, and theoretical FP32 size.

    Sizes are computed from each tensor's dtype and shape (read via
    ``get_slice``), not from loaded tensor data. They are in MiB
    (1024 * 1024 bytes), labelled "MB". Tensors whose dtype is not in
    DTYPE_TO_BYTES are shown as "N/A" and left out of the totals; the summary
    reports how many there were.

    Args:
        filepath: Path of the file to inspect.

    Returns:
        None. The report is printed to stdout. Warnings and errors go to
        stderr and are reported rather than raised.
    """
    # isfile (not exists) so that a directory is rejected here instead of
    # failing later inside safe_open with a generic error.
    if not os.path.isfile(filepath):
        print(f"Error: File not found at '{filepath}'", file=sys.stderr)
        return

    if not filepath.lower().endswith(".safetensors"):
        print(
            f"Warning: File '{filepath}' does not have a .safetensors extension. "
            "Attempting to read anyway.",
            file=sys.stderr,
        )

    bytes_per_mb = 1024 * 1024
    dtype_counts = Counter()
    total_actual_mb = 0.0
    total_fp32_equiv_mb = 0.0
    num_unknown_dtype = 0

    try:
        print(f"Inspecting tensors in: {filepath}\n")
        with safe_open(filepath, framework="pt", device="cpu") as f:
            tensor_keys = list(f.keys())
            if not tensor_keys:
                print("No tensors found in the file.")
                return

            # Name column is as wide as the longest key, but never narrower
            # than its own header text.
            max_key_len = max(len("Tensor Name"), max(len(k) for k in tensor_keys))
            print(_format_table_row(
                "Tensor Name", "Precision (dtype)",
                "Actual Size (MB)", "FP32 Equiv. (MB)", max_key_len,
            ))
            print(_table_separator(max_key_len))

            for key in tensor_keys:
                tensor_slice = f.get_slice(key)
                dtype_str = tensor_slice.get_dtype()
                num_elements = calculate_num_elements(tensor_slice.get_shape())
                bytes_per_el_actual = get_bytes_per_element(dtype_str)

                if bytes_per_el_actual is None:
                    num_unknown_dtype += 1
                    actual_size_mb_str = fp32_equiv_size_mb_str = "N/A"
                    print(
                        f"Warning: Unknown dtype '{dtype_str}' for tensor '{key}'. "
                        "Cannot calculate size.",
                        file=sys.stderr,
                    )
                else:
                    actual_size_mb_val = num_elements * bytes_per_el_actual / bytes_per_mb
                    # Theoretical FP32 size: 4 bytes per element.
                    fp32_equiv_size_mb_val = num_elements * 4 / bytes_per_mb
                    total_actual_mb += actual_size_mb_val
                    total_fp32_equiv_mb += fp32_equiv_size_mb_val
                    actual_size_mb_str = f"{actual_size_mb_val:.3f}"
                    fp32_equiv_size_mb_str = f"{fp32_equiv_size_mb_val:.3f}"

                print(_format_table_row(
                    key, dtype_str, actual_size_mb_str, fp32_equiv_size_mb_str,
                    max_key_len,
                ))
                dtype_counts[dtype_str] += 1

        _print_summary(len(tensor_keys), dtype_counts, total_actual_mb,
                       total_fp32_equiv_mb, num_unknown_dtype)

    except Exception as e:
        # Top-level boundary of a CLI helper: report the failure and return
        # instead of crashing with a traceback.
        print(f"An error occurred while processing '{filepath}':", file=sys.stderr)
        print(f"  {e}", file=sys.stderr)
        print(
            "Please ensure it's a valid .safetensors file and the 'safetensors' "
            "(and 'torch') libraries are installed correctly.",
            file=sys.stderr,
        )

if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the file to inspect.
    cli = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a .safetensors file."
    )
    cli.add_argument("filepath", help="Path to the .safetensors file to inspect.")
    cli_args = cli.parse_args()
    inspect_safetensors_precision_and_size(cli_args.filepath)