| import excepthook |
| import os.path |
| from functools import reduce |
| from pathlib import Path |
| import random |
|
|
| import hydra |
| import torch |
| from accelerate.utils import set_seed |
| from omegaconf import OmegaConf, DictConfig |
| from slider import Beatmap |
| from transformers.utils import cached_file |
|
|
| import osu_diffusion |
| import routed_pickle |
| from config import InferenceConfig, FidConfig |
| from diffusion_pipeline import DiffisionPipeline |
| from osuT5.osuT5.config import TrainConfig |
| from osuT5.osuT5.dataset.data_utils import events_of_type, TIMING_TYPES, merge_events |
| from osuT5.osuT5.inference import Preprocessor, Processor, Postprocessor, BeatmapConfig, GenerationConfig, \ |
| generation_config_from_beatmap, beatmap_config_from_beatmap, background_line |
| from osuT5.osuT5.inference.server import InferenceClient |
| from osuT5.osuT5.inference.super_timing_generator import SuperTimingGenerator |
| from osuT5.osuT5.model import Mapperatorinator |
| from osuT5.osuT5.tokenizer import Tokenizer, ContextType |
| from osuT5.osuT5.utils import get_model |
| from osu_diffusion import DiT_models |
| from osu_diffusion.config import DiffusionTrainConfig |
|
|
|
|
def prepare_args(args: FidConfig | InferenceConfig):
    """Resolve the inference device, configure global torch state and seed the RNGs.

    ``args.device`` may be ``"auto"``, ``"cpu"``, ``"cuda"``, ``"mps"`` or an
    indexed variant such as ``"cuda:1"``. If the requested backend is not
    available (or the device type is unknown), ``args.device`` is rewritten to
    ``"cpu"``. When ``args.seed`` is None a random seed is drawn and stored back
    on ``args``.

    Side effects: disables autograd globally, sets float32 matmul precision to
    ``'high'`` and calls ``set_seed(args.seed)``.
    """
    if args.device == "auto":
        if torch.cuda.is_available():
            print("Using CUDA for inference (auto-selected).")
            args.device = "cuda"
        elif torch.backends.mps.is_available():
            print("Using MPS for inference (auto-selected).")
            args.device = "mps"
        else:
            print("Using CPU for inference (auto-selected fallback).")
            args.device = "cpu"
    elif args.device != "cpu":
        # Compare only the device type so indexed devices like "cuda:1" are accepted
        # instead of being treated as an unknown device.
        device_type = str(args.device).split(":", 1)[0]
        if device_type == "cuda":
            if not torch.cuda.is_available():
                print("CUDA is not available. Falling back to CPU.")
                args.device = "cpu"
        elif device_type == "mps":
            # torch.backends.mps.is_available is the documented availability check.
            if not torch.backends.mps.is_available():
                print("MPS is not available. Falling back to CPU.")
                args.device = "cpu"
        else:
            print(
                f"Requested device '{args.device}' not available. Falling back to CPU."
            )
            args.device = "cpu"

    # Inference only: no gradients needed anywhere.
    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision('high')

    if args.seed is None:
        args.seed = random.randint(0, 2 ** 16)
        print(f"Random seed: {args.seed}")
    set_seed(args.seed)
|
|
|
|
def autofill_paths(args: InferenceConfig):
    """Autofill ``audio_path`` and ``output_path`` on ``args`` and validate all paths.

    Can be used either in the Web GUI or the CLI.

    * If ``args.beatmap_path`` points to an existing ``.osu`` file that parses,
      a missing audio path is taken from the beatmap's ``audio_filename``
      (resolved next to the beatmap), and a missing output path defaults to the
      beatmap's directory.
    * Otherwise, if the audio file exists and no output path is given, the
      output path defaults to the audio file's directory.

    The resolved paths are written back to ``args`` as strings (``""`` when unset).

    Returns:
        dict with ``'success'`` (True when no errors were found) and
        ``'errors'`` (list of human-readable messages).
    """
    errors = []

    beatmap_path = Path(args.beatmap_path) if args.beatmap_path else None
    output_path = Path(args.output_path) if args.output_path else None
    audio_path = Path(args.audio_path) if args.audio_path else None

    def is_valid_beatmap_file(path):
        """Check if the file exists and has a valid beatmap extension (.osu)."""
        if not path:
            return True
        return path.exists() and path.suffix.lower() == '.osu'

    if beatmap_path and is_valid_beatmap_file(beatmap_path):
        try:
            beatmap = Beatmap.from_path(beatmap_path)
            if not audio_path:
                audio_path = beatmap_path.parent / beatmap.audio_filename
            if not output_path:
                output_path = beatmap_path.parent
        except Exception as e:
            # Best-effort: collect the failure so the caller can report every problem at once.
            errors.append(f"Error reading beatmap file: {e}")
    elif audio_path and audio_path.exists() and not output_path:
        output_path = audio_path.parent

    valid_audio_extensions = {'.mp3', '.wav', '.ogg', '.m4a', '.flac'}
    if not audio_path:
        errors.append("Audio file path is required.")
    elif not audio_path.exists():
        errors.append(f"Audio file not found: {audio_path}")
    elif audio_path.suffix.lower() not in valid_audio_extensions:
        # Sort so the message is stable; set iteration order varies between runs.
        allowed = ', '.join(sorted(valid_audio_extensions))
        errors.append(f"Audio file must have one of the following extensions: {allowed}: {audio_path}")

    if beatmap_path:
        if not beatmap_path.exists():
            errors.append(f"Beatmap file not found: {beatmap_path}")
        elif not is_valid_beatmap_file(beatmap_path):
            errors.append(f"Beatmap file must have .osu extension: {beatmap_path}")

    args.audio_path = str(audio_path) if audio_path else ""
    args.output_path = str(output_path) if output_path else ""
    args.beatmap_path = str(beatmap_path) if beatmap_path else ""

    return {
        'success': len(errors) == 0,
        'errors': errors
    }
|
|
|
|
def get_args_from_beatmap(args: InferenceConfig, tokenizer: Tokenizer):
    """Fill unset generation settings on ``args``, from the reference beatmap when one is given.

    Paths are first autofilled/validated; any path error is printed and a
    ``ValueError`` is raised. Without a beatmap, fixed fallbacks are applied to
    settings that are still None. With a beatmap, unset settings are copied from
    ``generation_config_from_beatmap``. Title, artist, BPM, offset, background and
    preview time are then always overwritten from the beatmap.
    Every value that gets filled in is printed as ``Using <label> <value>``.
    """
    path_report = autofill_paths(args)
    if not path_report['success']:
        for message in path_report['errors']:
            print(f"Error: {message}")
        raise ValueError("Invalid paths provided. Please check the errors above.")

    def use(attr, label, value):
        # Read back through args so the printed value matches what the config object stored.
        setattr(args, attr, value)
        print(f"Using {label} {getattr(args, attr)}")

    if not args.beatmap_path:
        fallbacks = (
            ("gamemode", "game mode", 0),
            ("hp_drain_rate", "HP drain rate", 5),
            ("circle_size", "circle size", 4),
            ("overall_difficulty", "overall difficulty", 8),
            ("approach_rate", "approach rate", 9),
            ("slider_multiplier", "slider multiplier", 1.4),
            ("slider_tick_rate", "slider tick rate", 1),
            ("hitsounded", "hitsounded", True),
        )
        for attr, label, value in fallbacks:
            if getattr(args, attr) is None:
                use(attr, label, value)
        if args.keycount is None and args.gamemode == 3:
            use("keycount", "keycount", 4)
        return

    beatmap = Beatmap.from_path(Path(args.beatmap_path))

    content_contexts = [ContextType.MAP, ContextType.GD, ContextType.NO_HS]
    if beatmap.mode not in args.train.data.gamemodes and (any(c in content_contexts for c in args.in_context) or args.add_to_beatmap):
        raise ValueError(f"Beatmap mode {beatmap.mode} is not supported by the model. Supported modes: {args.train.data.gamemodes}")

    print(f"Using metadata from beatmap: {beatmap.display_name}")
    from_map = generation_config_from_beatmap(beatmap, tokenizer)

    if args.gamemode is None:
        use("gamemode", "game mode", from_map.gamemode)
    if args.beatmap_id is None and from_map.beatmap_id:
        use("beatmap_id", "beatmap ID", from_map.beatmap_id)
    if args.difficulty is None and from_map.difficulty != -1 and len(beatmap.hit_objects(stacking=False)) > 0:
        use("difficulty", "difficulty", from_map.difficulty)
    if args.mapper_id is None and beatmap.beatmap_id in tokenizer.beatmap_mapper:
        use("mapper_id", "mapper ID", from_map.mapper_id)
    if args.descriptors is None and beatmap.beatmap_id in tokenizer.beatmap_descriptors:
        use("descriptors", "descriptors", from_map.descriptors)

    shared_settings = (
        ("hp_drain_rate", "HP drain rate"),
        ("circle_size", "circle size"),
        ("overall_difficulty", "overall difficulty"),
        ("approach_rate", "approach rate"),
        ("slider_multiplier", "slider multiplier"),
        ("slider_tick_rate", "slider tick rate"),
        ("hitsounded", "hitsounded"),
    )
    for attr, label in shared_settings:
        if getattr(args, attr) is None:
            use(attr, label, getattr(from_map, attr))

    # Mania-only settings (gamemode 3).
    if args.gamemode == 3:
        if args.keycount is None:
            use("keycount", "keycount", int(from_map.keycount))
        if args.hold_note_ratio is None:
            use("hold_note_ratio", "hold note ratio", from_map.hold_note_ratio)
        if args.scroll_speed_ratio is None:
            use("scroll_speed_ratio", "scroll speed ratio", from_map.scroll_speed_ratio)

    song_info = beatmap_config_from_beatmap(beatmap)
    args.title = song_info.title
    args.artist = song_info.artist
    args.bpm = song_info.bpm
    args.offset = song_info.offset
    args.background = beatmap.background
    args.preview_time = song_info.preview_time
|
|
|
|
def get_tags_dict(args: DictConfig | InferenceConfig):
    """Collect the inference settings that describe a generation run.

    Works on a typed ``InferenceConfig`` or a raw ``DictConfig`` because only
    attribute access is used. Insertion order matters, since callers serialize
    the entries in order.

    Returns:
        dict of setting name -> value. Descriptor lists are rendered as the
        quoted string ``"[a,b]"`` (None when empty). ``in_context`` is rendered
        as ``[A,B]``: ContextType members use their upper-cased value, other
        entries are joined as-is.
    """
    def quoted_list(values):
        return f"\"[{','.join(values)}]\"" if values else None

    def context_label(ctx):
        return ctx.value.upper() if isinstance(ctx, ContextType) else ctx

    tags = {}
    for name in ("lookback", "lookahead", "beatmap_id", "difficulty", "mapper_id",
                 "year", "hitsounded", "hold_note_ratio", "scroll_speed_ratio"):
        tags[name] = getattr(args, name)
    tags["descriptors"] = quoted_list(args.descriptors)
    tags["negative_descriptors"] = quoted_list(args.negative_descriptors)
    for name in ("timing_leniency", "seed", "add_to_beatmap", "start_time", "end_time"):
        tags[name] = getattr(args, name)
    tags["in_context"] = "[" + ",".join(map(context_label, args.in_context)) + "]"
    for name in ("cfg_scale", "temperature", "timing_temperature", "mania_column_temperature",
                 "taiko_hit_temperature", "top_p", "top_k", "parallel", "do_sample", "num_beams",
                 "super_timing", "timer_num_beams", "timer_bpm_threshold", "timer_cfg_scale",
                 "timer_iterations", "generate_positions", "diff_cfg_scale", "max_seq_len",
                 "overlap_buffer"):
        tags[name] = getattr(args, name)
    return tags
|
|
|
|
def get_config(args: InferenceConfig):
    """Build the generation and beatmap configs for an inference run.

    Inference settings whose value differs from ``configs/inference/default.yaml``
    are recorded as space-separated ``key=value`` pairs in the ``BeatmapConfig``
    tags. Unset difficulty settings fall back to fixed defaults
    (HP 5, CS/keycount 4, OD 8, AR 9, slider multiplier 1.4, tick rate 1), and an
    unset gamemode falls back to 0 in both configs.

    Returns:
        ``(GenerationConfig, BeatmapConfig)``.
    """
    tags = get_tags_dict(args)

    # Resolve the defaults file next to this script first: Hydra (version_base 1.1)
    # may change the working directory. Fall back to the original CWD-relative path.
    default_config_path = Path(__file__).resolve().parent / "configs" / "inference" / "default.yaml"
    if not default_config_path.exists():
        default_config_path = Path("configs/inference/default.yaml")
    defaults = get_tags_dict(OmegaConf.load(default_config_path))

    # Only keep settings that differ from the defaults.
    tags = " ".join(f"{k}={v}" for k, v in tags.items() if v != defaults[k])

    # Resolve once so both configs agree on the mode.
    gamemode = args.gamemode if args.gamemode is not None else 0

    return GenerationConfig(
        gamemode=gamemode,
        beatmap_id=args.beatmap_id,
        difficulty=args.difficulty,
        mapper_id=args.mapper_id,
        year=args.year,
        hitsounded=args.hitsounded if args.hitsounded is not None else True,
        hp_drain_rate=args.hp_drain_rate,
        circle_size=args.circle_size,
        overall_difficulty=args.overall_difficulty,
        approach_rate=args.approach_rate,
        slider_multiplier=args.slider_multiplier or 1.4,
        slider_tick_rate=args.slider_tick_rate or 1,
        keycount=args.keycount if args.keycount is not None else 4,
        hold_note_ratio=args.hold_note_ratio,
        scroll_speed_ratio=args.scroll_speed_ratio,
        descriptors=args.descriptors,
        negative_descriptors=args.negative_descriptors,
    ), BeatmapConfig(
        title=args.title,
        artist=args.artist,
        title_unicode=args.title,
        artist_unicode=args.artist,
        audio_filename=Path(args.audio_path).name,
        hp_drain_rate=args.hp_drain_rate or 5,
        # For mania (gamemode 3) the circle-size field carries the key count.
        circle_size=(args.keycount if gamemode == 3 else args.circle_size) or 4,
        overall_difficulty=args.overall_difficulty or 8,
        approach_rate=args.approach_rate or 9,
        slider_multiplier=args.slider_multiplier or 1.4,
        slider_tick_rate=args.slider_tick_rate or 1,
        creator=args.creator,
        version=args.version,
        tags=tags,
        background_line=background_line(args.background),
        preview_time=args.preview_time,
        bpm=args.bpm,
        offset=args.offset,
        mode=gamemode,
    )
|
|
|
|
def generate(
        args: InferenceConfig,
        *,
        audio_path: str = None,
        beatmap_path: str = None,
        output_path: str = None,
        generation_config: GenerationConfig,
        beatmap_config: BeatmapConfig,
        model: Mapperatorinator | InferenceClient,
        tokenizer,
        diff_model=None,
        diff_tokenizer=None,
        refine_model=None,
        verbose=True,
):
    """Generate a beatmap for an audio file and write it out.

    Pipeline steps:
    1. Load and segment the audio.
    2. Optionally generate timing first, or take it from the reference beatmap.
    3. Generate the remaining requested output contexts.
    4. Optionally place objects with the diffusion model.
    5. Post-process, then write or merge the result and optionally export an ``.osz``.

    Args:
        args: Inference configuration.
        audio_path, beatmap_path, output_path: Override the matching ``args``
            fields when not None.
        generation_config, beatmap_config: Configs for the run (``main`` builds
            them with ``get_config``).
        model, tokenizer: Main model (or inference-server client) and its tokenizer.
        diff_model, diff_tokenizer, refine_model: Diffusion models, only used when
            positions are generated.
        verbose: Print progress messages.

    Returns:
        ``(result, result_path, osz_path)``: the post-processor result, the file
        it was written to (None if nothing was written) and the exported ``.osz``
        path (None if not exported).

    Raises:
        FileNotFoundError: The audio file, or a given beatmap file, is missing.
        ValueError: The beatmap file does not have the ``.osu`` extension.
    """
    audio_path = args.audio_path if audio_path is None else audio_path
    beatmap_path = args.beatmap_path if beatmap_path is None else beatmap_path
    output_path = args.output_path if output_path is None else output_path

    # is_file() implies exists(), so one check covers both failure modes.
    if not Path(audio_path).is_file():
        raise FileNotFoundError(f"Provided audio file path does not exist: {audio_path}")
    if beatmap_path:
        beatmap_path_obj = Path(beatmap_path)
        if not beatmap_path_obj.is_file():
            raise FileNotFoundError(f"Provided beatmap file path does not exist: {beatmap_path}")
        if beatmap_path_obj.suffix.lower() != '.osu':
            raise ValueError(f"Beatmap file must have .osu extension: {beatmap_path}")

    preprocessor = Preprocessor(args, parallel=args.parallel)
    processor = Processor(args, model, tokenizer)
    postprocessor = Postprocessor(args)

    audio = preprocessor.load(audio_path)
    sequences = preprocessor.segment(audio)
    extra_in_context = {}
    # Copy so removing TIMING below does not mutate the config.
    output_type = args.output_type.copy()

    timing_events, timing_times, timing = None, None, None
    if args.super_timing and ContextType.NONE in args.in_context:
        # Dedicated timing pass; its timing becomes extra context for the main pass.
        super_timing_generator = SuperTimingGenerator(args, model, tokenizer)
        timing_events, timing_times = super_timing_generator.generate(audio, generation_config, verbose=verbose)
        timing = postprocessor.generate_timing(timing_events)
        extra_in_context[ContextType.TIMING] = timing
        if ContextType.TIMING in output_type:
            output_type.remove(ContextType.TIMING)
    elif (ContextType.NONE in args.in_context and ContextType.MAP in output_type and
          not any((ContextType.NONE in ctx["in"] or len(ctx["in"]) == 0) and ContextType.MAP in ctx["out"]
                  for ctx in args.train.data.context_types)):
        # No training context type maps NONE/empty input to MAP output, so generate
        # timing first and condition the map generation on it.
        timing_events, timing_times = processor.generate(
            sequences=sequences,
            generation_config=generation_config,
            in_context=[ContextType.NONE],
            out_context=[ContextType.TIMING],
            verbose=verbose,
        )[0]
        timing_events, timing_times = events_of_type(timing_events, timing_times, TIMING_TYPES)
        timing = postprocessor.generate_timing(timing_events)
        extra_in_context[ContextType.TIMING] = timing
        if ContextType.TIMING in output_type:
            output_type.remove(ContextType.TIMING)
    elif ContextType.TIMING in args.in_context or (
            args.train.data.add_timing and any(t in args.in_context for t in [ContextType.GD, ContextType.NO_HS])):
        # Reuse the reference beatmap's timing points that have no parent
        # (presumably the uninherited/red-line points).
        timing = [tp for tp in Beatmap.from_path(Path(beatmap_path)).timing_points if tp.parent is None]

    if len(output_type) > 0:
        result = processor.generate(
            sequences=sequences,
            generation_config=generation_config,
            in_context=args.in_context,
            out_context=output_type,
            beatmap_path=beatmap_path,
            extra_in_context=extra_in_context,
            verbose=verbose,
        )
        events, _ = reduce(merge_events, result)

        if timing is None and (ContextType.TIMING in args.output_type or args.train.data.add_timing):
            timing = postprocessor.generate_timing(events)

        if args.resnap_events and timing is not None:
            events = postprocessor.resnap_events(events, timing)
    else:
        # Nothing left to generate, e.g. only TIMING was requested and produced above.
        events = timing_events

    if args.generate_positions and args.gamemode in [0, 2] and ContextType.MAP in output_type:
        diffusion_pipeline = DiffisionPipeline(args, diff_model, diff_tokenizer, refine_model)
        events = diffusion_pipeline.generate(
            events=events,
            generation_config=generation_config,
            timing=timing,
            verbose=verbose,
        )

    result = postprocessor.generate(
        events=events,
        beatmap_config=beatmap_config,
        timing=timing,
    )

    result_path = None
    osz_path = None
    if args.add_to_beatmap:
        result_path = postprocessor.add_to_beatmap(result, beatmap_path)
        if verbose:
            print(f"Added generated content to {result_path}")
    elif output_path is not None and output_path != "":
        result_path = postprocessor.write_result(result, output_path)
        if verbose:
            print(f"Generated beatmap saved to {result_path}")

    if args.export_osz:
        if result_path is None:
            # Nothing was written (no output path), so there is no beatmap file to package.
            if verbose:
                print("Skipping .osz export: no beatmap file was written.")
        else:
            osz_path = postprocessor.export_osz(result_path, audio_path, output_path)
            if verbose:
                print(f"Generated .osz saved to {osz_path}")

    return result, result_path, osz_path
|
|
|
|
def load_model(
        ckpt_path_str: str,
        t5_args: TrainConfig,
        device,
        max_batch_size: int = 8,
        use_server: bool = False,
        precision: str = "fp32",
):
    """Load the tokenizer and the main model (or an inference-server client).

    When the checkpoint directory contains both ``pytorch_model.bin`` and
    ``custom_checkpoint_0.pkl``, the model is built with ``get_model`` and those
    state files are loaded. Otherwise ``ckpt_path_str`` is passed to the
    ``from_pretrained`` loaders.

    With ``use_server`` the model is not loaded here. Instead an
    ``InferenceClient`` receives the loader callables and a socket address
    derived from the checkpoint path.

    Returns:
        ``(model_or_client, tokenizer)``.

    Raises:
        ValueError: if ``ckpt_path_str`` is empty.
    """
    if ckpt_path_str == "":
        raise ValueError("Model path is empty.")

    ckpt_path = Path(ckpt_path_str)
    weights_file = ckpt_path / "pytorch_model.bin"
    state_file = ckpt_path / "custom_checkpoint_0.pkl"

    def has_training_checkpoint():
        # Both files must be present; otherwise fall back to from_pretrained loading.
        return weights_file.exists() and state_file.exists()

    def tokenizer_loader():
        if has_training_checkpoint():
            saved_state = torch.load(state_file, pickle_module=routed_pickle, weights_only=False)
            loaded = Tokenizer()
            loaded.load_state_dict(saved_state)
            return loaded
        return Tokenizer.from_pretrained(ckpt_path_str)

    tokenizer = tokenizer_loader()

    def model_loader():
        if has_training_checkpoint():
            weights = torch.load(weights_file, map_location=device, weights_only=True)
            net = get_model(t5_args, tokenizer)
            net.load_state_dict(weights)
        else:
            net = Mapperatorinator.from_pretrained(ckpt_path_str)
            net.generation_config.disable_compile = True

        net.eval()
        net.to(device)

        if precision == "bf16":
            # Cast every submodule except the root and anything named "spectrogram".
            for module_name, submodule in net.named_modules():
                if module_name and "spectrogram" not in module_name:
                    submodule.to(torch.bfloat16)

        print(f"Model loaded: {ckpt_path_str} on device {device}")
        return net

    if use_server:
        client = InferenceClient(
            model_loader,
            tokenizer_loader,
            max_batch_size=max_batch_size,
            socket_path=get_server_address(ckpt_path_str),
        )
        return client, tokenizer
    return model_loader(), tokenizer
|
|
|
|
def get_server_address(ckpt_path_str: str):
    """
    Get a valid socket address for the OS and model version.

    Spaces, slashes, backslashes and dots in the checkpoint string are replaced
    with underscores. On POSIX the address is a Unix domain socket under
    ``/tmp``; elsewhere it is a Windows named pipe.

    Unix socket paths are limited by the kernel (``sun_path`` holds 108 bytes on
    Linux and 104 on macOS, including the terminating NUL). If the flattened name
    would exceed that, it is truncated and suffixed with a stable hash of the
    original string, so distinct checkpoints still get distinct sockets. Short
    names are unchanged.
    """
    import hashlib

    max_unix_path_bytes = 103  # strictest common limit (macOS: 104 incl. NUL)

    token = ckpt_path_str.translate(str.maketrans({" ": "_", "/": "_", "\\": "_", ".": "_"}))

    if os.name == 'posix':
        address = f"/tmp/{token}.sock"
        if len(address.encode("utf-8", "surrogatepass")) > max_unix_path_bytes:
            digest = hashlib.sha1(ckpt_path_str.encode("utf-8", "surrogatepass")).hexdigest()[:16]
            prefix = token.encode("utf-8", "surrogatepass")[:48].decode("utf-8", errors="ignore")
            address = f"/tmp/{prefix}_{digest}.sock"
        return address
    else:
        return fr"\\.\pipe\{token}"
|
|
|
|
def load_diff_model(
        ckpt_path,
        diff_args: DiffusionTrainConfig,
        device,
):
    """Load a diffusion (DiT) model and its tokenizer.

    ``ckpt_path`` is either an existing local directory containing
    ``tokenizer.pkl`` and ``model_ema.pkl``, or an identifier that
    ``transformers.utils.cached_file`` resolves (e.g. a Hugging Face Hub repo id).

    Returns:
        ``(model, tokenizer)`` with the model moved to ``device`` and in eval mode.

    Raises:
        ValueError: if ``ckpt_path`` is empty.
    """
    if not ckpt_path:
        # An empty path would otherwise resolve to tokenizer.pkl/model_ema.pkl in the CWD.
        raise ValueError("Diffusion model path is empty.")

    if os.path.exists(ckpt_path):
        local_dir = Path(ckpt_path)
        tokenizer_file = local_dir / "tokenizer.pkl"
        model_file = local_dir / "model_ema.pkl"
    else:
        tokenizer_file = cached_file(ckpt_path, "tokenizer.pkl")
        model_file = cached_file(ckpt_path, "model_ema.pkl")

    # NOTE(review): weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from trusted sources, including files fetched via cached_file.
    tokenizer_state = torch.load(tokenizer_file, pickle_module=routed_pickle, weights_only=False)
    tokenizer = osu_diffusion.utils.tokenizer.Tokenizer()
    tokenizer.load_state_dict(tokenizer_state)

    ema_state = torch.load(model_file, pickle_module=routed_pickle, weights_only=False, map_location=device)
    model = DiT_models[diff_args.model.model](
        context_size=diff_args.model.context_size,
        class_size=tokenizer.num_tokens,
    ).to(device)
    model.load_state_dict(ema_state)
    model.eval()
    return model, tokenizer
|
|
|
|
@hydra.main(config_path="configs/inference", config_name="v30", version_base="1.1")
def main(args: InferenceConfig):
    """Hydra CLI entry point (config ``configs/inference/v30`` by default).

    Prepares device and seed, loads the main model and (optionally) the
    diffusion models, derives missing settings from the reference beatmap and
    runs ``generate``.

    Returns:
        Whatever ``generate`` returns: ``(result, result_path, osz_path)``.
    """
    prepare_args(args)

    model, tokenizer = load_model(args.model_path, args.train, args.device, args.max_batch_size, args.use_server, args.precision)

    diff_model, diff_tokenizer, refine_model = None, None, None
    if args.generate_positions:
        diff_model, diff_tokenizer = load_diff_model(args.diff_ckpt, args.diffusion, args.device)

        # The refiner is optional. Only load it from an existing local path, and
        # tolerate an unset (None/"") option instead of crashing in os.path.exists.
        if args.diff_refine_ckpt and os.path.exists(args.diff_refine_ckpt):
            refine_model = load_diff_model(args.diff_refine_ckpt, args.diffusion, args.device)[0]

        if args.compile:
            diff_model.forward = torch.compile(diff_model.forward, mode="reduce-overhead", fullgraph=True)

    get_args_from_beatmap(args, tokenizer)
    generation_config, beatmap_config = get_config(args)

    return generate(
        args,
        generation_config=generation_config,
        beatmap_path=args.beatmap_path,
        beatmap_config=beatmap_config,
        model=model,
        tokenizer=tokenizer,
        diff_model=diff_model,
        diff_tokenizer=diff_tokenizer,
        refine_model=refine_model,
    )
|
|
|
|
if __name__ == "__main__":
    # The @hydra.main decorator parses command-line overrides and passes the composed config to main().
    main()
|
|