# BeatHeritage-v1 / inference.py
# Duplicated by fourmansyah from hongminh54/BeatHeritage-v1 (commit 12a8e0f)
import excepthook # noqa
import os.path
from functools import reduce
from pathlib import Path
import random
import hydra
import torch
from accelerate.utils import set_seed
from omegaconf import OmegaConf, DictConfig
from slider import Beatmap
from transformers.utils import cached_file
import osu_diffusion
import routed_pickle
from config import InferenceConfig, FidConfig
from diffusion_pipeline import DiffisionPipeline
from osuT5.osuT5.config import TrainConfig
from osuT5.osuT5.dataset.data_utils import events_of_type, TIMING_TYPES, merge_events
from osuT5.osuT5.inference import Preprocessor, Processor, Postprocessor, BeatmapConfig, GenerationConfig, \
generation_config_from_beatmap, beatmap_config_from_beatmap, background_line
from osuT5.osuT5.inference.server import InferenceClient
from osuT5.osuT5.inference.super_timing_generator import SuperTimingGenerator
from osuT5.osuT5.model import Mapperatorinator
from osuT5.osuT5.tokenizer import Tokenizer, ContextType
from osuT5.osuT5.utils import get_model
from osu_diffusion import DiT_models
from osu_diffusion.config import DiffusionTrainConfig
def prepare_args(args: FidConfig | InferenceConfig):
    """Resolve the inference device, disable autograd, and seed all RNGs.

    Mutates ``args.device`` and ``args.seed`` in place.
    """
    requested = args.device
    if requested == "auto":
        # Pick the best available backend: CUDA > MPS > CPU.
        if torch.cuda.is_available():
            print("Using CUDA for inference (auto-selected).")
            args.device = "cuda"
        elif torch.mps.is_available():
            print("Using MPS for inference (auto-selected).")
            args.device = "mps"
        else:
            print("Using CPU for inference (auto-selected fallback).")
            args.device = "cpu"
    elif requested == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"
    elif requested == "mps" and not torch.mps.is_available():
        print("MPS is not available. Falling back to CPU.")
        args.device = "cpu"
    elif requested not in ("cpu", "cuda", "mps"):
        print(
            f"Requested device '{args.device}' not available. Falling back to CPU."
        )
        args.device = "cpu"

    # Inference never needs gradients; 'high' matmul precision trades exact
    # fp32 for faster TF32 matmuls on supported hardware.
    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision('high')

    if args.seed is None:
        args.seed = random.randint(0, 2 ** 16)
    print(f"Random seed: {args.seed}")
    set_seed(args.seed)
def autofill_paths(args: InferenceConfig):
    """Autofills audio_path and output_path. Can be used either in Web GUI or CLI.

    Mutates the path fields on ``args`` in place and returns a dict with
    ``success`` (bool) and ``errors`` (list of human-readable messages).
    """
    problems = []

    # Normalize to Path objects; empty strings become None.
    bm_path = Path(args.beatmap_path) if args.beatmap_path else None
    out_path = Path(args.output_path) if args.output_path else None
    snd_path = Path(args.audio_path) if args.audio_path else None

    def beatmap_file_ok(path):
        """Check if the file exists and has a valid beatmap extension (.osu)."""
        if not path:
            return True  # Empty path is valid (optional)
        return path.exists() and path.suffix.lower() == '.osu'

    if bm_path and beatmap_file_ok(bm_path):
        # Case 1: a valid beatmap is given - derive audio and output from it.
        try:
            parsed = Beatmap.from_path(bm_path)
            if not snd_path:
                snd_path = bm_path.parent / parsed.audio_filename
            if not out_path:
                out_path = bm_path.parent
        except Exception as e:
            problems.append(f"Error reading beatmap file: {e}")
    elif snd_path and snd_path.exists() and not out_path:
        # Case 2: only audio is given - write the output next to it.
        out_path = snd_path.parent

    # Validate all paths.
    valid_audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac'}
    if not snd_path:
        problems.append("Audio file path is required.")
    elif not snd_path.exists():
        problems.append(f"Audio file not found: {snd_path}")
    elif snd_path.suffix.lower() not in valid_audio_extensions:
        problems.append(f"Audio file must have one of the following extensions: {', '.join(valid_audio_extensions)}: {snd_path}")

    if bm_path:
        if not bm_path.exists():
            problems.append(f"Beatmap file not found: {bm_path}")
        elif not beatmap_file_ok(bm_path):
            problems.append(f"Beatmap file must have .osu extension: {bm_path}")

    # Write the resolved paths back onto args as strings.
    args.audio_path = str(snd_path) if snd_path else ""
    args.output_path = str(out_path) if out_path else ""
    args.beatmap_path = str(bm_path) if bm_path else ""

    return {
        'success': len(problems) == 0,
        'errors': problems,
    }
def get_args_from_beatmap(args: InferenceConfig, tokenizer: Tokenizer):
    """Populate unset generation args from the reference beatmap, or from fair defaults.

    Mutates ``args`` in place. When ``args.beatmap_path`` is set, metadata and
    difficulty settings are read from that beatmap; otherwise sensible defaults
    are used for any unset fields.

    Raises:
        ValueError: when path validation fails, or the beatmap's game mode is
            not supported by the model while its content is actually consumed.
    """
    result = autofill_paths(args)
    if not result['success']:
        for error in result['errors']:
            print(f"Error: {error}")
        raise ValueError("Invalid paths provided. Please check the errors above.")

    def fill(attr: str, value, label: str):
        # Assign `value` to args.<attr> only when it is still unset, and log it.
        if getattr(args, attr) is None:
            setattr(args, attr, value)
            print(f"Using {label} {getattr(args, attr)}")

    if not args.beatmap_path:
        # No reference beatmap: populate fair defaults for unset args.
        fill("gamemode", 0, "game mode")
        fill("hp_drain_rate", 5, "HP drain rate")
        fill("circle_size", 4, "circle size")
        fill("overall_difficulty", 8, "overall difficulty")
        fill("approach_rate", 9, "approach rate")
        fill("slider_multiplier", 1.4, "slider multiplier")
        fill("slider_tick_rate", 1, "slider tick rate")
        fill("hitsounded", True, "hitsounded")
        if args.gamemode == 3:
            # Mania only: default key count.
            fill("keycount", 4, "keycount")
        return

    beatmap_path = Path(args.beatmap_path)
    beatmap = Beatmap.from_path(beatmap_path)

    # Only reject unsupported modes when the beatmap content itself is consumed
    # (as map/GD/no-hitsound context, or when adding to the beatmap).
    if beatmap.mode not in args.train.data.gamemodes and (any(c in [ContextType.MAP, ContextType.GD, ContextType.NO_HS] for c in args.in_context) or args.add_to_beatmap):
        raise ValueError(f"Beatmap mode {beatmap.mode} is not supported by the model. Supported modes: {args.train.data.gamemodes}")

    print(f"Using metadata from beatmap: {beatmap.display_name}")
    generation_config = generation_config_from_beatmap(beatmap, tokenizer)

    fill("gamemode", generation_config.gamemode, "game mode")
    if generation_config.beatmap_id:
        fill("beatmap_id", generation_config.beatmap_id, "beatmap ID")
    # Difficulty only applies when known (-1 means unknown) and the map has objects.
    if args.difficulty is None and generation_config.difficulty != -1 and len(beatmap.hit_objects(stacking=False)) > 0:
        fill("difficulty", generation_config.difficulty, "difficulty")
    # Mapper/descriptor metadata only when the tokenizer knows this beatmap.
    if beatmap.beatmap_id in tokenizer.beatmap_mapper:
        fill("mapper_id", generation_config.mapper_id, "mapper ID")
    if beatmap.beatmap_id in tokenizer.beatmap_descriptors:
        fill("descriptors", generation_config.descriptors, "descriptors")
    fill("hp_drain_rate", generation_config.hp_drain_rate, "HP drain rate")
    fill("circle_size", generation_config.circle_size, "circle size")
    fill("overall_difficulty", generation_config.overall_difficulty, "overall difficulty")
    fill("approach_rate", generation_config.approach_rate, "approach rate")
    fill("slider_multiplier", generation_config.slider_multiplier, "slider multiplier")
    fill("slider_tick_rate", generation_config.slider_tick_rate, "slider tick rate")
    fill("hitsounded", generation_config.hitsounded, "hitsounded")
    if args.gamemode == 3:
        # Mania-specific settings.
        fill("keycount", int(generation_config.keycount), "keycount")
        fill("hold_note_ratio", generation_config.hold_note_ratio, "hold note ratio")
        fill("scroll_speed_ratio", generation_config.scroll_speed_ratio, "scroll speed ratio")

    # Song metadata is always taken verbatim from the reference beatmap.
    beatmap_config = beatmap_config_from_beatmap(beatmap)
    args.title = beatmap_config.title
    args.artist = beatmap_config.artist
    args.bpm = beatmap_config.bpm
    args.offset = beatmap_config.offset
    args.background = beatmap.background
    args.preview_time = beatmap_config.preview_time
def get_tags_dict(args: DictConfig | InferenceConfig):
    """Map the inference settings to the key/value pairs recorded as beatmap tags."""

    def quoted_list(values):
        # Descriptor lists serialize as "\"[a,b]\""; unset/empty becomes None.
        return f"\"[{','.join(values)}]\"" if values else None

    # Context entries may arrive as ContextType enum members or plain strings.
    context_str = f"[{','.join(ctx.value.upper() if isinstance(ctx, ContextType) else ctx for ctx in args.in_context)}]"

    return {
        "lookback": args.lookback,
        "lookahead": args.lookahead,
        "beatmap_id": args.beatmap_id,
        "difficulty": args.difficulty,
        "mapper_id": args.mapper_id,
        "year": args.year,
        "hitsounded": args.hitsounded,
        "hold_note_ratio": args.hold_note_ratio,
        "scroll_speed_ratio": args.scroll_speed_ratio,
        "descriptors": quoted_list(args.descriptors),
        "negative_descriptors": quoted_list(args.negative_descriptors),
        "timing_leniency": args.timing_leniency,
        "seed": args.seed,
        "add_to_beatmap": args.add_to_beatmap,
        "start_time": args.start_time,
        "end_time": args.end_time,
        "in_context": context_str,
        "cfg_scale": args.cfg_scale,
        "temperature": args.temperature,
        "timing_temperature": args.timing_temperature,
        "mania_column_temperature": args.mania_column_temperature,
        "taiko_hit_temperature": args.taiko_hit_temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "parallel": args.parallel,
        "do_sample": args.do_sample,
        "num_beams": args.num_beams,
        "super_timing": args.super_timing,
        "timer_num_beams": args.timer_num_beams,
        "timer_bpm_threshold": args.timer_bpm_threshold,
        "timer_cfg_scale": args.timer_cfg_scale,
        "timer_iterations": args.timer_iterations,
        "generate_positions": args.generate_positions,
        "diff_cfg_scale": args.diff_cfg_scale,
        "max_seq_len": args.max_seq_len,
        "overlap_buffer": args.overlap_buffer,
    }
def get_config(args: InferenceConfig):
    """Build the (GenerationConfig, BeatmapConfig) pair for this run.

    The beatmap tag string records every inference setting that differs from
    the stock defaults in configs/inference/default.yaml.
    """
    # Collect the current settings and keep only the non-default ones.
    current_tags = get_tags_dict(args)
    default_tags = get_tags_dict(OmegaConf.load("configs/inference/default.yaml"))
    changed = {key: value for key, value in current_tags.items() if value != default_tags[key]}
    tags = " ".join(f"{key}={value}" for key, value in changed.items())

    # Fields the generation config cannot leave unknown get fair defaults here.
    generation_config = GenerationConfig(
        gamemode=0 if args.gamemode is None else args.gamemode,
        beatmap_id=args.beatmap_id,
        difficulty=args.difficulty,
        mapper_id=args.mapper_id,
        year=args.year,
        hitsounded=True if args.hitsounded is None else args.hitsounded,
        hp_drain_rate=args.hp_drain_rate,
        circle_size=args.circle_size,
        overall_difficulty=args.overall_difficulty,
        approach_rate=args.approach_rate,
        slider_multiplier=args.slider_multiplier or 1.4,
        slider_tick_rate=args.slider_tick_rate or 1,
        keycount=4 if args.keycount is None else args.keycount,
        hold_note_ratio=args.hold_note_ratio,
        scroll_speed_ratio=args.scroll_speed_ratio,
        descriptors=args.descriptors,
        negative_descriptors=args.negative_descriptors,
    )
    beatmap_config = BeatmapConfig(
        title=args.title,
        artist=args.artist,
        title_unicode=args.title,
        artist_unicode=args.artist,
        audio_filename=Path(args.audio_path).name,
        hp_drain_rate=args.hp_drain_rate or 5,
        # Mania (mode 3) stores its key count in the circle size field.
        circle_size=(args.keycount if args.gamemode == 3 else args.circle_size) or 4,
        overall_difficulty=args.overall_difficulty or 8,
        approach_rate=args.approach_rate or 9,
        slider_multiplier=args.slider_multiplier or 1.4,
        slider_tick_rate=args.slider_tick_rate or 1,
        creator=args.creator,
        version=args.version,
        tags=tags,
        background_line=background_line(args.background),
        preview_time=args.preview_time,
        bpm=args.bpm,
        offset=args.offset,
        mode=args.gamemode,
    )
    return generation_config, beatmap_config
def generate(
    args: InferenceConfig,
    *,
    audio_path: str = None,
    beatmap_path: str = None,
    output_path: str = None,
    generation_config: GenerationConfig,
    beatmap_config: BeatmapConfig,
    model: Mapperatorinator | InferenceClient,
    tokenizer,
    diff_model=None,
    diff_tokenizer=None,
    refine_model=None,
    verbose=True,
):
    """Run the full beatmap generation pipeline for one audio file.

    Loads and segments the audio, generates or reuses timing, generates events
    with the seq2seq model, optionally refines hit-object positions with the
    diffusion model, and writes the resulting beatmap (and optionally a .osz).

    Args:
        args: Inference configuration; path arguments below override its fields.
        audio_path / beatmap_path / output_path: Optional overrides for the
            corresponding ``args`` fields.
        generation_config: Settings fed to the seq2seq/diffusion generators.
        beatmap_config: Metadata/difficulty settings for the output .osu file.
        model: The Mapperatorinator model or a client to an inference server.
        tokenizer: Tokenizer matching ``model``.
        diff_model / diff_tokenizer / refine_model: Diffusion components, only
            used when ``args.generate_positions`` applies.
        verbose: Print progress and output locations.

    Returns:
        Tuple ``(result, result_path, osz_path)``; the paths are None when
        nothing was written to disk.

    Raises:
        FileNotFoundError: audio or beatmap path does not point to a file.
        ValueError: beatmap path does not end in .osu.
    """
    # Explicit keyword paths take precedence over the ones stored on args.
    audio_path = args.audio_path if audio_path is None else audio_path
    beatmap_path = args.beatmap_path if beatmap_path is None else beatmap_path
    output_path = args.output_path if output_path is None else output_path

    # Do some validation
    if not Path(audio_path).exists() or not Path(audio_path).is_file():
        raise FileNotFoundError(f"Provided audio file path does not exist: {audio_path}")

    if beatmap_path:
        beatmap_path_obj = Path(beatmap_path)
        if not beatmap_path_obj.exists() or not beatmap_path_obj.is_file():
            raise FileNotFoundError(f"Provided beatmap file path does not exist: {beatmap_path}")
        # Validate beatmap file type
        if beatmap_path_obj.suffix.lower() != '.osu':
            raise ValueError(f"Beatmap file must have .osu extension: {beatmap_path}")

    preprocessor = Preprocessor(args, parallel=args.parallel)
    processor = Processor(args, model, tokenizer)
    postprocessor = Postprocessor(args)

    # Spectrogram sequences are the model's input windows over the audio.
    audio = preprocessor.load(audio_path)
    sequences = preprocessor.segment(audio)

    extra_in_context = {}
    # Copy so removals below don't mutate the configured output types.
    output_type = args.output_type.copy()

    # Auto generate timing if not provided in in_context and required for the model and this output_type
    timing_events, timing_times, timing = None, None, None
    if args.super_timing and ContextType.NONE in args.in_context:
        # Branch 1: dedicated high-effort timing generation was requested.
        super_timing_generator = SuperTimingGenerator(args, model, tokenizer)
        timing_events, timing_times = super_timing_generator.generate(audio, generation_config, verbose=verbose)
        timing = postprocessor.generate_timing(timing_events)
        extra_in_context[ContextType.TIMING] = timing
        if ContextType.TIMING in output_type:
            output_type.remove(ContextType.TIMING)
    elif (ContextType.NONE in args.in_context and ContextType.MAP in output_type and
          not any((ContextType.NONE in ctx["in"] or len(ctx["in"]) == 0) and ContextType.MAP in ctx["out"] for ctx in args.train.data.context_types)):
        # Branch 2: the model cannot map without timing context, so
        # generate timing and convert in_context to timing context.
        timing_events, timing_times = processor.generate(
            sequences=sequences,
            generation_config=generation_config,
            in_context=[ContextType.NONE],
            out_context=[ContextType.TIMING],
            verbose=verbose,
        )[0]
        timing_events, timing_times = events_of_type(timing_events, timing_times, TIMING_TYPES)
        timing = postprocessor.generate_timing(timing_events)
        extra_in_context[ContextType.TIMING] = timing
        if ContextType.TIMING in output_type:
            output_type.remove(ContextType.TIMING)
    elif ContextType.TIMING in args.in_context or (
            args.train.data.add_timing and any(t in args.in_context for t in [ContextType.GD, ContextType.NO_HS])):
        # Branch 3: exact timing is provided in the other beatmap, so we don't
        # need to generate it. Only uninherited (parent is None) points count.
        timing = [tp for tp in Beatmap.from_path(Path(beatmap_path)).timing_points if tp.parent is None]

    # Generate beatmap
    if len(output_type) > 0:
        result = processor.generate(
            sequences=sequences,
            generation_config=generation_config,
            in_context=args.in_context,
            out_context=output_type,
            beatmap_path=beatmap_path,
            extra_in_context=extra_in_context,
            verbose=verbose,
        )
        # Merge the per-context event streams into a single event timeline.
        events, _ = reduce(merge_events, result)

        if timing is None and (ContextType.TIMING in args.output_type or args.train.data.add_timing):
            timing = postprocessor.generate_timing(events)

        # Resnap timing events
        if args.resnap_events and timing is not None:
            events = postprocessor.resnap_events(events, timing)
    else:
        # Every requested output type was already covered by the timing stage.
        events = timing_events

    # Generate positions with diffusion (osu! standard and catch modes only).
    if args.generate_positions and args.gamemode in [0, 2] and ContextType.MAP in output_type:
        diffusion_pipeline = DiffisionPipeline(args, diff_model, diff_tokenizer, refine_model)
        events = diffusion_pipeline.generate(
            events=events,
            generation_config=generation_config,
            timing=timing,
            verbose=verbose,
        )

    result = postprocessor.generate(
        events=events,
        beatmap_config=beatmap_config,
        timing=timing,
    )

    result_path = None
    osz_path = None
    if args.add_to_beatmap:
        # Merge the generated content into the existing reference beatmap file.
        result_path = postprocessor.add_to_beatmap(result, beatmap_path)
        if verbose:
            print(f"Added generated content to {result_path}")
    elif output_path is not None and output_path != "":
        result_path = postprocessor.write_result(result, output_path)
        if verbose:
            print(f"Generated beatmap saved to {result_path}")
        if args.export_osz:
            osz_path = postprocessor.export_osz(result_path, audio_path, output_path)
            if verbose:
                print(f"Generated .osz saved to {osz_path}")

    return result, result_path, osz_path
def load_model(
    ckpt_path_str: str,
    t5_args: TrainConfig,
    device,
    max_batch_size: int = 8,
    use_server: bool = False,
    precision: str = "fp32",
):
    """Load a Mapperatorinator checkpoint plus its tokenizer.

    Supports both a local accelerate checkpoint directory (containing
    pytorch_model.bin and custom_checkpoint_0.pkl) and a pretrained
    directory / hub repo loadable via from_pretrained.

    Returns:
        Tuple of (model or InferenceClient, tokenizer). When ``use_server``
        is True the loaders are handed to an InferenceClient instead of
        being called in this process.
    """
    if ckpt_path_str == "":
        raise ValueError("Model path is empty.")

    ckpt_path = Path(ckpt_path_str)

    def load_tokenizer():
        # Checked at call time so server workers re-evaluate the checkpoint layout.
        has_local_ckpt = (ckpt_path / "pytorch_model.bin").exists() and (ckpt_path / "custom_checkpoint_0.pkl").exists()
        if not has_local_ckpt:
            return Tokenizer.from_pretrained(ckpt_path_str)
        state = torch.load(ckpt_path / "custom_checkpoint_0.pkl", pickle_module=routed_pickle, weights_only=False)
        tok = Tokenizer()
        tok.load_state_dict(state)
        return tok

    tokenizer = load_tokenizer()

    def load_mapper_model():
        has_local_ckpt = (ckpt_path / "pytorch_model.bin").exists() and (ckpt_path / "custom_checkpoint_0.pkl").exists()
        if not has_local_ckpt:
            mdl = Mapperatorinator.from_pretrained(ckpt_path_str)
            mdl.generation_config.disable_compile = True
        else:
            state = torch.load(ckpt_path / "pytorch_model.bin", map_location=device, weights_only=True)
            mdl = get_model(t5_args, tokenizer)
            mdl.load_state_dict(state)
        mdl.eval()
        mdl.to(device)
        if precision == "bf16":
            # Cast every submodule to bfloat16 except for the spectrogram module
            for name, module in mdl.named_modules():
                if name != "" and "spectrogram" not in name:
                    module.to(torch.bfloat16)
        print(f"Model loaded: {ckpt_path_str} on device {device}")
        return mdl

    if use_server:
        model = InferenceClient(
            load_mapper_model,
            load_tokenizer,
            max_batch_size=max_batch_size,
            socket_path=get_server_address(ckpt_path_str),
        )
    else:
        model = load_mapper_model()
    return model, tokenizer
def get_server_address(ckpt_path_str: str):
    """
    Get a valid socket address for the OS and model version.
    """
    # Sanitize the checkpoint path into an identifier-safe name:
    # spaces, slashes, backslashes and dots all become underscores.
    safe_name = ckpt_path_str.translate(str.maketrans({" ": "_", "/": "_", "\\": "_", ".": "_"}))
    if os.name == 'posix':
        # Linux/macOS: Unix domain socket under /tmp.
        return f"/tmp/{safe_name}.sock"
    # Windows: named pipe.
    return fr"\\.\pipe\{safe_name}"
def load_diff_model(
    ckpt_path,
    diff_args: DiffusionTrainConfig,
    device,
):
    """Load the osu-diffusion position model (EMA weights) and its tokenizer.

    A non-empty path that does not exist locally is treated as a Hugging Face
    hub repo id and resolved via cached_file; otherwise files are read from
    the local checkpoint directory.
    """
    if ckpt_path != "" and not os.path.exists(ckpt_path):
        tokenizer_file = cached_file(ckpt_path, "tokenizer.pkl")
        model_file = cached_file(ckpt_path, "model_ema.pkl")
    else:
        base = Path(ckpt_path)
        tokenizer_file = base / "tokenizer.pkl"
        model_file = base / "model_ema.pkl"

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # trusted checkpoints.
    tokenizer_state = torch.load(tokenizer_file, pickle_module=routed_pickle, weights_only=False)
    tokenizer = osu_diffusion.utils.tokenizer.Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    ema_state = torch.load(model_file, pickle_module=routed_pickle, weights_only=False, map_location=device)
    diff_model = DiT_models[diff_args.model.model](
        context_size=diff_args.model.context_size,
        class_size=tokenizer.num_tokens,
    ).to(device)
    diff_model.load_state_dict(ema_state)
    diff_model.eval()  # important!
    return diff_model, tokenizer
@hydra.main(config_path="configs/inference", config_name="v30", version_base="1.1")
def main(args: InferenceConfig):
    """CLI entry point: prepare args, load models, and generate one beatmap.

    Hydra composes ``args`` from configs/inference/v30.yaml plus CLI overrides.
    Returns whatever ``generate`` returns: (result, result_path, osz_path).
    """
    prepare_args(args)

    model, tokenizer = load_model(args.model_path, args.train, args.device, args.max_batch_size, args.use_server, args.precision)

    diff_model, diff_tokenizer, refine_model = None, None, None
    if args.generate_positions:
        # Diffusion model places hit objects; an optional refine checkpoint is
        # loaded only when its path exists on disk.
        diff_model, diff_tokenizer = load_diff_model(args.diff_ckpt, args.diffusion, args.device)
        if os.path.exists(args.diff_refine_ckpt):
            refine_model = load_diff_model(args.diff_refine_ckpt, args.diffusion, args.device)[0]
        if args.compile:
            diff_model.forward = torch.compile(diff_model.forward, mode="reduce-overhead", fullgraph=True)

    # Fill unset args from the reference beatmap (or defaults), then build configs.
    get_args_from_beatmap(args, tokenizer)
    generation_config, beatmap_config = get_config(args)
    return generate(
        args,
        generation_config=generation_config,
        beatmap_path=args.beatmap_path,
        beatmap_config=beatmap_config,
        model=model,
        tokenizer=tokenizer,
        diff_model=diff_model,
        diff_tokenizer=diff_tokenizer,
        refine_model=refine_model,
    )
if __name__ == "__main__":
    # Hydra parses CLI overrides and invokes main with the composed config.
    main()