Commit ·
12a8e0f
0
Parent(s):
Duplicate from hongminh54/BeatHeritage-v1
Browse filesCo-authored-by: hongminh54 <hongminh54@users.noreply.huggingface.co>
This view is limited to 50 files because it contains too many changes. See raw diff
- .devcontainer/devcontainer.json +42 -0
- .devcontainer/docker-compose.yml +26 -0
- .gitattributes +35 -0
- .github/FUNDING.yml +15 -0
- .gitignore +11 -0
- Dockerfile +8 -0
- LICENSE +21 -0
- README.md +323 -0
- audit_all_configs.py +157 -0
- beatheritage_postprocessor.py +474 -0
- benchmark_comparison.py +469 -0
- calc_fid.py +417 -0
- classifier/README.md +34 -0
- classifier/classify.py +175 -0
- classifier/configs/inference.yaml +14 -0
- classifier/configs/model/model.yaml +9 -0
- classifier/configs/model/whisper_base.yaml +6 -0
- classifier/configs/model/whisper_base_v2.yaml +7 -0
- classifier/configs/model/whisper_small.yaml +6 -0
- classifier/configs/model/whisper_tiny.yaml +6 -0
- classifier/configs/train.yaml +82 -0
- classifier/configs/train_v1.yaml +4 -0
- classifier/configs/train_v2.yaml +14 -0
- classifier/configs/train_v3.yaml +17 -0
- classifier/count_classes.py +56 -0
- classifier/libs/__init__.py +1 -0
- classifier/libs/dataset/__init__.py +3 -0
- classifier/libs/dataset/data_utils.py +308 -0
- classifier/libs/dataset/ors_dataset.py +490 -0
- classifier/libs/dataset/osu_parser.py +460 -0
- classifier/libs/model/__init__.py +1 -0
- classifier/libs/model/model.py +145 -0
- classifier/libs/model/spectrogram.py +55 -0
- classifier/libs/tokenizer/__init__.py +2 -0
- classifier/libs/tokenizer/event.py +53 -0
- classifier/libs/tokenizer/tokenizer.py +201 -0
- classifier/libs/utils/__init__.py +1 -0
- classifier/libs/utils/model_utils.py +190 -0
- classifier/libs/utils/routed_pickle.py +17 -0
- classifier/test.py +32 -0
- classifier/train.py +82 -0
- cli_inference.sh +491 -0
- colab/beatheritage_v1_inference.ipynb +510 -0
- colab/classifier_classify.ipynb +133 -0
- colab/mai_mod_inference.ipynb +148 -0
- colab/mapperatorinator_inference.ipynb +305 -0
- collate_results.py +158 -0
- compose.yaml +25 -0
- config.py +197 -0
- configs/calc_fid.yaml +43 -0
.devcontainer/devcontainer.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
| 2 |
+
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
|
| 3 |
+
{
|
| 4 |
+
"name": "Existing Docker Compose (Extend)",
|
| 5 |
+
|
| 6 |
+
// Update the 'dockerComposeFile' list if you have more compose files or use different names.
|
| 7 |
+
// The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
|
| 8 |
+
"dockerComposeFile": [
|
| 9 |
+
"../compose.yaml",
|
| 10 |
+
"docker-compose.yml"
|
| 11 |
+
],
|
| 12 |
+
|
| 13 |
+
// The 'service' property is the name of the service for the container that VS Code should
|
| 14 |
+
// use. Update this value and .devcontainer/docker-compose.yml to the real service name.
|
| 15 |
+
"service": "Mapperatorinator",
|
| 16 |
+
|
| 17 |
+
// The optional 'workspaceFolder' property is the path VS Code should open by default when
|
| 18 |
+
// connected. This is typically a file mount in .devcontainer/docker-compose.yml
|
| 19 |
+
"workspaceFolder": "/workspace/Mapperatorinator",
|
| 20 |
+
// "workspaceFolder": "/",
|
| 21 |
+
|
| 22 |
+
// Features to add to the dev container. More info: https://containers.dev/features.
|
| 23 |
+
// "features": {},
|
| 24 |
+
|
| 25 |
+
// Use 'forwardPorts' to make a list of ports inside the container available locally.
|
| 26 |
+
// "forwardPorts": [],
|
| 27 |
+
|
| 28 |
+
// Uncomment the next line if you want start specific services in your Docker Compose config.
|
| 29 |
+
// "runServices": [],
|
| 30 |
+
|
| 31 |
+
// Uncomment the next line if you want to keep your containers running after VS Code shuts down.
|
| 32 |
+
// "shutdownAction": "none",
|
| 33 |
+
|
| 34 |
+
// Uncomment the next line to run commands after the container is created.
|
| 35 |
+
"postCreateCommand": "git config --global --add safe.directory /workspace/Mapperatorinator"
|
| 36 |
+
|
| 37 |
+
// Configure tool-specific properties.
|
| 38 |
+
// "customizations": {},
|
| 39 |
+
|
| 40 |
+
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
|
| 41 |
+
// "remoteUser": "devcontainer"
|
| 42 |
+
}
|
.devcontainer/docker-compose.yml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version: '3.8'
|
| 2 |
+
services:
|
| 3 |
+
# Update this to the name of the service you want to work with in your docker-compose.yml file
|
| 4 |
+
mapperatorinator:
|
| 5 |
+
# Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
|
| 6 |
+
# folder. Note that the path of the Dockerfile and context is relative to the *primary*
|
| 7 |
+
# docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
|
| 8 |
+
# array). The sample below assumes your primary file is in the root of your project.
|
| 9 |
+
#
|
| 10 |
+
# build:
|
| 11 |
+
# context: .
|
| 12 |
+
# dockerfile: .devcontainer/Dockerfile
|
| 13 |
+
|
| 14 |
+
volumes:
|
| 15 |
+
# Update this to wherever you want VS Code to mount the folder of your project
|
| 16 |
+
- ..:/workspace:cached
|
| 17 |
+
|
| 18 |
+
# Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust.
|
| 19 |
+
# cap_add:
|
| 20 |
+
# - SYS_PTRACE
|
| 21 |
+
# security_opt:
|
| 22 |
+
# - seccomp:unconfined
|
| 23 |
+
|
| 24 |
+
# Overrides default command so things don't shut down after the process ends.
|
| 25 |
+
command: /bin/sh -c "while sleep 1000; do :; done"
|
| 26 |
+
|
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.github/FUNDING.yml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# These are supported funding model platforms
|
| 2 |
+
|
| 3 |
+
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
| 4 |
+
patreon: # Replace with a single Patreon username
|
| 5 |
+
open_collective: # Replace with a single Open Collective username
|
| 6 |
+
ko_fi: OliBomby
|
| 7 |
+
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
| 8 |
+
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
| 9 |
+
liberapay: # Replace with a single Liberapay username
|
| 10 |
+
issuehunt: # Replace with a single IssueHunt username
|
| 11 |
+
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
| 12 |
+
polar: # Replace with a single Polar username
|
| 13 |
+
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
| 14 |
+
thanks_dev: # Replace with a single thanks.dev username
|
| 15 |
+
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
__pycache__
|
| 3 |
+
logs
|
| 4 |
+
logs_fid
|
| 5 |
+
multirun
|
| 6 |
+
tensorboard_logs
|
| 7 |
+
.idea
|
| 8 |
+
test
|
| 9 |
+
test_inference.py
|
| 10 |
+
test_inference_mai_mod.py
|
| 11 |
+
.windsurf
|
Dockerfile
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
|
| 2 |
+
|
| 3 |
+
RUN apt-get -y update && apt-get -y upgrade && apt-get install -y git && apt-get install -y --no-install-recommends ffmpeg && rm -rf /var/lib/apt/lists/*
|
| 4 |
+
RUN pip install accelerate pydub nnAudio PyYAML transformers hydra-core tensorboard lightning pandas pyarrow einops 'git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh#egg=slider' torch_tb_profiler wandb ninja
|
| 5 |
+
RUN MAX_JOBS=4 pip install flash-attn --no-build-isolation
|
| 6 |
+
|
| 7 |
+
# Modify .bashrc to include the custom prompt
|
| 8 |
+
RUN echo 'if [ -f /.dockerenv ]; then export PS1="(docker) $PS1"; fi' >> /root/.bashrc
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 OliBomby
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BeatHeritage
|
| 2 |
+
|
| 3 |
+
🎯 **NEW: BeatHeritage V1 - Enhanced Stability & Quality** | [Try on Colab](https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb) | [Documentation](docs/BEATHERITAGE_V1.md)
|
| 4 |
+
|
| 5 |
+
Try the generative model [here](https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb), or MaiMod [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb). Check out a video showcase [here](https://youtu.be/FEr7t1L2EoA).
|
| 6 |
+
|
| 7 |
+
BeatHeritage (formerly Mapperatorinator) is a multi-model framework that uses spectrogram inputs to generate fully featured osu! beatmaps for all gamemodes and [assist modding beatmaps](#maimod-the-ai-driven-modding-tool).
|
| 8 |
+
The goal of this project is to automatically generate rankable quality osu! beatmaps from any song with a high degree of customizability.
|
| 9 |
+
|
| 10 |
+
## 🚀 What's New in BeatHeritage V1
|
| 11 |
+
|
| 12 |
+
- **Enhanced Stability**: Optimized sampling parameters (temperature 0.85, top_p 0.92) for more consistent generation
|
| 13 |
+
- **Quality Control**: Automatic spacing correction, overlap detection, and flow optimization
|
| 14 |
+
- **Pattern Variety**: Advanced pattern generation with diversity enhancement
|
| 15 |
+
- **All Gamemodes**: Full support for std, taiko, ctb, and mania with mode-specific optimizations
|
| 16 |
+
- **Performance**: Flash attention, mixed precision (BF16), and gradient checkpointing
|
| 17 |
+
- **Custom Postprocessor**: Advanced post-processing with flow optimization and style preservation
|
| 18 |
+
- **Benchmark Tools**: Compare performance with previous models
|
| 19 |
+
- **Easy Setup**: Auto-setup script with model downloading from Hugging Face
|
| 20 |
+
|
| 21 |
+
This project is built upon [osuT5](https://github.com/gyataro/osuT5) and [osu-diffusion](https://github.com/OliBomby/osu-diffusion). In developing this, I spent about 2500 hours of GPU compute across 142 runs on my 4060 Ti and rented 4090 instances on vast.ai.
|
| 22 |
+
|
| 23 |
+
#### Use this tool responsibly. Always disclose the use of AI in your beatmaps.
|
| 24 |
+
|
| 25 |
+
## Installation
|
| 26 |
+
|
| 27 |
+
The instruction below allows you to generate beatmaps on your local machine, alternatively you can run it in the cloud with the [colab notebook](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mapperatorinator_inference.ipynb).
|
| 28 |
+
|
| 29 |
+
### 1. Clone the repository
|
| 30 |
+
|
| 31 |
+
```sh
|
| 32 |
+
git clone https://github.com/OliBomby/Mapperatorinator.git
|
| 33 |
+
cd Mapperatorinator
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### 2. (Optional) Create virtual environment
|
| 37 |
+
|
| 38 |
+
Use Python 3.10, later versions will might not be compatible with the dependencies.
|
| 39 |
+
|
| 40 |
+
```sh
|
| 41 |
+
python -m venv .venv
|
| 42 |
+
|
| 43 |
+
# In cmd.exe
|
| 44 |
+
.venv\Scripts\activate.bat
|
| 45 |
+
# In PowerShell
|
| 46 |
+
.venv\Scripts\Activate.ps1
|
| 47 |
+
# In Linux or MacOS
|
| 48 |
+
source .venv/bin/activate
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### 3. Install dependencies
|
| 52 |
+
|
| 53 |
+
- Python 3.10
|
| 54 |
+
- [Git](https://git-scm.com/downloads)
|
| 55 |
+
- [ffmpeg](http://www.ffmpeg.org/)
|
| 56 |
+
- [PyTorch](https://pytorch.org/get-started/locally/): Make sure to follow the Get Started guide so you install `torch` and `torchaudio` with GPU support.
|
| 57 |
+
|
| 58 |
+
- and the remaining Python dependencies:
|
| 59 |
+
|
| 60 |
+
```sh
|
| 61 |
+
pip install -r requirements.txt
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Web GUI (Recommended)
|
| 65 |
+
|
| 66 |
+
For a more user-friendly experience, consider using the Web UI. It provides a graphical interface to configure generation parameters, start the process, and monitor the output.
|
| 67 |
+
|
| 68 |
+
### Launch the GUI
|
| 69 |
+
|
| 70 |
+
Navigate to the cloned `Mapperatorinator` directory in your terminal and run:
|
| 71 |
+
|
| 72 |
+
```sh
|
| 73 |
+
python web-ui.py
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
This will start a local web server and automatically open the UI in a new window.
|
| 77 |
+
|
| 78 |
+
### Using the GUI
|
| 79 |
+
|
| 80 |
+
- **Configure:** Set input/output paths using the form fields and "Browse" buttons. Adjust generation parameters like gamemode, difficulty, style (year, mapper ID, descriptors), timing, specific features (hitsounds, super timing), and more, mirroring the command-line options. (Note: If you provide a `beatmap_path`, the UI will automatically determine the `audio_path` and `output_path` from it, so you can leave those fields blank)
|
| 81 |
+
- **Start:** Click the "Start Inference" button to begin the beatmap generation.
|
| 82 |
+
- **Cancel:** You can stop the ongoing process using the "Cancel Inference" button.
|
| 83 |
+
- **Open Output:** Once finished, use the "Open Output Folder" button for quick access to the generated files.
|
| 84 |
+
|
| 85 |
+
The Web UI acts as a convenient wrapper around the `inference.py` script. For advanced options or troubleshooting, refer to the command-line instructions.
|
| 86 |
+
|
| 87 |
+

|
| 88 |
+
|
| 89 |
+
## Command-Line Inference
|
| 90 |
+
|
| 91 |
+
For users who prefer the command line or need access to advanced configurations, follow the steps below. **Note:** For a simpler graphical interface, please see the [Web UI (Recommended)](#web-ui-recommended) section above.
|
| 92 |
+
|
| 93 |
+
Run `inference.py` and pass in some arguments to generate beatmaps. For this use [Hydra override syntax](https://hydra.cc/docs/advanced/override_grammar/basic/). See `configs/inference_v29.yaml` for all available parameters.
|
| 94 |
+
```
|
| 95 |
+
python inference.py \
|
| 96 |
+
audio_path [Path to input audio] \
|
| 97 |
+
output_path [Path to output directory] \
|
| 98 |
+
beatmap_path [Path to .osu file to autofill metadata, and output_path, or use as reference] \
|
| 99 |
+
|
| 100 |
+
gamemode [Game mode to generate 0=std, 1=taiko, 2=ctb, 3=mania] \
|
| 101 |
+
difficulty [Difficulty star rating to generate] \
|
| 102 |
+
mapper_id [Mapper user ID for style] \
|
| 103 |
+
year [Upload year to simulate] \
|
| 104 |
+
hitsounded [Whether to add hitsounds] \
|
| 105 |
+
slider_multiplier [Slider velocity multiplier] \
|
| 106 |
+
circle_size [Circle size] \
|
| 107 |
+
keycount [Key count for mania] \
|
| 108 |
+
hold_note_ratio [Hold note ratio for mania 0-1] \
|
| 109 |
+
scroll_speed_ratio [Scroll speed ratio for mania and ctb 0-1] \
|
| 110 |
+
descriptors [List of beatmap user tags for style] \
|
| 111 |
+
negative_descriptors [List of beatmap user tags for classifier-free guidance] \
|
| 112 |
+
|
| 113 |
+
add_to_beatmap [Whether to add generated content to the reference beatmap instead of making a new beatmap] \
|
| 114 |
+
start_time [Generation start time in milliseconds] \
|
| 115 |
+
end_time [Generation end time in milliseconds] \
|
| 116 |
+
in_context [List of additional context to provide to the model [NONE,TIMING,KIAI,MAP,GD,NO_HS]] \
|
| 117 |
+
output_type [List of content types to generate] \
|
| 118 |
+
cfg_scale [Scale of the classifier-free guidance] \
|
| 119 |
+
super_timing [Whether to use slow accurate variable BPM timing generator] \
|
| 120 |
+
seed [Random seed for generation] \
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
Example:
|
| 124 |
+
```
|
| 125 |
+
python inference.py beatmap_path="'C:\Users\USER\AppData\Local\osu!\Songs\1 Kenji Ninuma - DISCO PRINCE\Kenji Ninuma - DISCOPRINCE (peppy) [Normal].osu'" gamemode=0 difficulty=5.5 year=2023 descriptors="['jump aim','clean']" in_context=[TIMING,KIAI]
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Interactive CLI
|
| 129 |
+
For those who prefer a terminal-based workflow but want a guided setup, the interactive CLI script is an excellent alternative to the Web UI.
|
| 130 |
+
|
| 131 |
+
### Launch the CLI
|
| 132 |
+
Navigate to the cloned directory. You may need to make the script executable first.
|
| 133 |
+
|
| 134 |
+
```sh
|
| 135 |
+
# Make the script executable (only needs to be done once)
|
| 136 |
+
chmod +x cli_inference.sh
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
```sh
|
| 140 |
+
# Run the script
|
| 141 |
+
./cli_inference.sh
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
### Using the CLI
|
| 145 |
+
The script will walk you through a series of prompts to configure all generation parameters, just like the Web UI.
|
| 146 |
+
|
| 147 |
+
It uses a color-coded interface for clarity.
|
| 148 |
+
It provides an advanced multi-select menu for choosing style descriptors using your arrow keys and spacebar.
|
| 149 |
+
After you've answered all the questions, it will display the final command for your review.
|
| 150 |
+
You can then confirm to execute it directly or cancel and copy the command for manual use.
|
| 151 |
+
|
| 152 |
+
## Generation Tips
|
| 153 |
+
|
| 154 |
+
- You can edit `configs/inference_v29.yaml` and add your arguments there instead of typing them in the terminal every time.
|
| 155 |
+
- All available descriptors can be found [here](https://osu.ppy.sh/wiki/en/Beatmap/Beatmap_tags).
|
| 156 |
+
- Always provide a year argument between 2007 and 2023. If you leave it unknown, the model might generate with an inconsistent style.
|
| 157 |
+
- Always provide a difficulty argument. If you leave it unknown, the model might generate with an inconsistent difficulty.
|
| 158 |
+
- Increase the `cfg_scale` parameter to increase the effectiveness of the `mapper_id` and `descriptors` arguments.
|
| 159 |
+
- You can use the `negative_descriptors` argument to guide the model away from certain styles. This only works when `cfg_scale > 1`. Make sure the number of negative descriptors is equal to the number of descriptors.
|
| 160 |
+
- If your song style and desired beatmap style don't match well, the model might not follow your directions. For example, its hard to generate a high SR, high SV beatmap for a calm song.
|
| 161 |
+
- If you already have timing and kiai times done for a song, then you can give this to the model to greatly increase inference speed and accuracy: Use the `beatmap_path` and `in_context=[TIMING,KIAI]` arguments.
|
| 162 |
+
- To remap just a part of your beatmap, use the `beatmap_path`, `start_time`, `end_time`, and `add_to_beatmap=true` arguments.
|
| 163 |
+
- To generate a guest difficulty for a beatmap, use the `beatmap_path` and `in_context=[GD,TIMING,KIAI]` arguments.
|
| 164 |
+
- To generate hitsounds for a beatmap, use the `beatmap_path` and `in_context=[NO_HS,TIMING,KIAI]` arguments.
|
| 165 |
+
- To generate only timing for a song, use the `super_timing=true` and `output_type=[TIMING]` arguments.
|
| 166 |
+
|
| 167 |
+
## MaiMod: The AI-driven Modding Tool
|
| 168 |
+
|
| 169 |
+
MaiMod is a modding tool for osu! beatmaps that uses Mapperatorinator predictions to find potential faults and inconsistencies which can't be detected by other automatic modding tools like [Mapset Verifier](https://github.com/Naxesss/MapsetVerifier).
|
| 170 |
+
It can detect issues like:
|
| 171 |
+
- Incorrect snapping or rhythmic patterns
|
| 172 |
+
- Inaccurate timing points
|
| 173 |
+
- Inconsistent hit object positions or new combo placements
|
| 174 |
+
- Weird slider shapes
|
| 175 |
+
- Inconsistent hitsounds or volumes
|
| 176 |
+
|
| 177 |
+
You can try MaiMod [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb), or run it locally:
|
| 178 |
+
To run MaiMod locally, you'll need to install Mapperatorinator. Then, run the `mai_mod.py` script, specifying your beatmap's path with the `beatmap_path` argument.
|
| 179 |
+
```sh
|
| 180 |
+
python mai_mod.py beatmap_path="'C:\Users\USER\AppData\Local\osu!\Songs\1 Kenji Ninuma - DISCO PRINCE\Kenji Ninuma - DISCOPRINCE (peppy) [Normal].osu'"
|
| 181 |
+
```
|
| 182 |
+
This will print the modding suggestions to the console, which you can then apply to your beatmap manually.
|
| 183 |
+
Suggestions are ordered chronologically and grouped into categories.
|
| 184 |
+
The first value in the circle indicates the 'surprisal' which is a measure of how unexpected the model found the issue to be, so you can prioritize the most important issues.
|
| 185 |
+
|
| 186 |
+
The model can make mistakes, especially on low surprisal issues, so always double-check the suggestions before applying them to your beatmap.
|
| 187 |
+
The main goal is to help you narrow down the search space for potential issues, so you don't have to manually check every single hit object in your beatmap.
|
| 188 |
+
|
| 189 |
+
### MaiMod GUI
|
| 190 |
+
To run the MaiMod Web UI, you'll need to install Mapperatorinator.
|
| 191 |
+
Then, run the `mai_mod_ui.py` script. This will start a local web server and automatically open the UI in a new window:
|
| 192 |
+
|
| 193 |
+
```sh
|
| 194 |
+
python mai_mod_ui.py
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
<img width="850" height="1019" alt="afbeelding" src="https://github.com/user-attachments/assets/67c03a43-a7bd-4265-a5b1-5e4d62aca1fa" />
|
| 198 |
+
|
| 199 |
+
## Overview
|
| 200 |
+
|
| 201 |
+
### Tokenization
|
| 202 |
+
|
| 203 |
+
Mapperatorinator converts osu! beatmaps into an intermediate event representation that can be directly converted to and from tokens.
|
| 204 |
+
It includes hit objects, hitsounds, slider velocities, new combos, timing points, kiai times, and taiko/mania scroll speeds.
|
| 205 |
+
|
| 206 |
+
Here is a small examle of the tokenization process:
|
| 207 |
+
|
| 208 |
+

|
| 209 |
+
|
| 210 |
+
To save on vocabulary size, time events are quantized to 10ms intervals and position coordinates are quantized to 32 pixel grid points.
|
| 211 |
+
|
| 212 |
+
### Model architecture
|
| 213 |
+
The model is basically a wrapper around the [HF Transformers Whisper](https://huggingface.co/docs/transformers/en/model_doc/whisper#transformers.WhisperForConditionalGeneration) model, with custom input embeddings and loss function.
|
| 214 |
+
Model size amounts to 219M parameters.
|
| 215 |
+
This model was found to be faster and more accurate than T5 for this task.
|
| 216 |
+
|
| 217 |
+
The high-level overview of the model's input-output is as follows:
|
| 218 |
+
|
| 219 |
+

|
| 220 |
+
|
| 221 |
+
The model uses Mel spectrogram frames as encoder input, with one frame per input position. The model decoder output at each step is a softmax distribution over a discrete, predefined, vocabulary of events. Outputs are sparse, events are only needed when a hit-object occurs, instead of annotating every single audio frame.
|
| 222 |
+
|
| 223 |
+
### Multitask training format
|
| 224 |
+
|
| 225 |
+

|
| 226 |
+
|
| 227 |
+
Before the SOS token are additional tokens that facilitate conditional generation. These tokens include the gamemode, difficulty, mapper ID, year, and other metadata.
|
| 228 |
+
During training, these tokens do not have accompanying labels, so they are never output by the model.
|
| 229 |
+
Also during training there is a random chance that a metadata token gets replaced by an 'unknown' token, so during inference we can use these 'unknown' tokens to reduce the amount of metadata we have to give to the model.
|
| 230 |
+
|
| 231 |
+
### Seamless long generation
|
| 232 |
+
|
| 233 |
+
The context length of the model is 8.192 seconds long. This is obviously not enough to generate a full beatmap, so we have to split the song into multiple windows and generate the beatmap in small parts.
|
| 234 |
+
To make sure that the generated beatmap does not have noticeable seams in between windows, we use a 90% overlap and generate the windows sequentially.
|
| 235 |
+
Each generation window except the first starts with the decoder pre-filled up to 50% of the generation window with tokens from the previous windows.
|
| 236 |
+
We use a logit processor to make sure that the model can't generate time tokens that are in the first 50% of the generation window.
|
| 237 |
+
Additionally, the last 40% of the generation window is reserved for the next window. Any generated time tokens in that range are treated as EOS tokens.
|
| 238 |
+
This ensures that each generated token is conditioned on at least 4 seconds of previous tokens and 3.3 seconds of future audio to anticipate.
|
| 239 |
+
|
| 240 |
+
To prevent offset drifting during long generation, random offsets have been added to time events in the decoder during training.
|
| 241 |
+
This forces it to correct timing errors by listening to the onsets in the audio instead, and results in a consistently accurate offset.
|
| 242 |
+
|
| 243 |
+
### Refined coordinates with diffusion
|
| 244 |
+
|
| 245 |
+
Position coordinates generated by the decoder are quantized to 32 pixel grid points, so afterward we use diffusion to denoise the coordinates to the final positions.
|
| 246 |
+
For this we trained a modified version of [osu-diffusion](https://github.com/OliBomby/osu-diffusion) that is specialized to only the last 10% of the noise schedule, and accepts the more advanced metadata tokens that Mapperatorinator uses for conditional generation.
|
| 247 |
+
|
| 248 |
+
Since the Mapperatorinator model outputs the SV of sliders, the required length of the slider is fixed regardless of the shape of the control point path.
|
| 249 |
+
Therefore, we try to guide the diffusion process to create coordinates that fit the required slider lengths.
|
| 250 |
+
We do this by recalculating the slider end positions after every step of the diffusion process based on the required length and the current control point path.
|
| 251 |
+
This means that the diffusion process does not have direct control over the slider end positions, but it can still influence them by changing the control point path.
|
| 252 |
+
|
| 253 |
+
### Post-processing
|
| 254 |
+
|
| 255 |
+
Mapperatorinator does some extra post-processing to improve the quality of the generated beatmap:
|
| 256 |
+
|
| 257 |
+
- Refine position coordinates with diffusion.
|
| 258 |
+
- Resnap time events to the nearest tick using the snap divisors generated by the model.
|
| 259 |
+
- Snap near-perfect positional overlaps.
|
| 260 |
+
- Convert mania column events to X coordinates.
|
| 261 |
+
- Generate slider paths for taiko drumrolls.
|
| 262 |
+
- Fix big discrepancies in required slider length and control point path length.
|
| 263 |
+
|
| 264 |
+
### Super timing generator
|
| 265 |
+
|
| 266 |
+
Super timing generator is an algorithm that improves the precision and accuracy of generated timing by infering timing for the whole song 20 times and averaging the results.
|
| 267 |
+
This is useful for songs with variable BPM, or songs with BPM changes. The result is almost perfect with only sometimes a section that needs manual adjustment.
|
| 268 |
+
|
| 269 |
+
## Training
|
| 270 |
+
|
| 271 |
+
The instruction below creates a training environment on your local machine.
|
| 272 |
+
|
| 273 |
+
### 1. Clone the repository
|
| 274 |
+
|
| 275 |
+
```sh
|
| 276 |
+
git clone https://github.com/OliBomby/Mapperatorinator.git
|
| 277 |
+
cd Mapperatorinator
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
### 2. Create dataset
|
| 281 |
+
|
| 282 |
+
Create your own dataset using the [Mapperator console app](https://github.com/mappingtools/Mapperator/blob/master/README.md#create-a-high-quality-dataset). It requires an [osu! OAuth client token](https://osu.ppy.sh/home/account/edit) to verify beatmaps and get additional metadata. Place the dataset in a `datasets` directory next to the `Mapperatorinator` directory.
|
| 283 |
+
|
| 284 |
+
```sh
|
| 285 |
+
Mapperator.ConsoleApp.exe dataset2 -t "/Mapperatorinator/datasets/beatmap_descriptors.csv" -i "path/to/osz/files" -o "/datasets/cool_dataset"
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
### 3. Create docker container
|
| 289 |
+
Training in your venv is also possible, but we recommend using Docker on WSL for better performance.
|
| 290 |
+
```sh
|
| 291 |
+
docker compose up -d --force-recreate
|
| 292 |
+
docker attach mapperatorinator_space
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
### 4. Configure parameters and begin training
|
| 296 |
+
|
| 297 |
+
All configurations are located in `./configs/osut5/train.yaml`. Begin training by calling `osuT5/train.py`.
|
| 298 |
+
|
| 299 |
+
```sh
|
| 300 |
+
python osuT5/train.py -cn train_v29 train_dataset_path="/workspace/datasets/cool_dataset" test_dataset_path="/workspace/datasets/cool_dataset" train_dataset_end=90 test_dataset_start=90 test_dataset_end=100
|
| 301 |
+
```
|
| 302 |
+
|
| 303 |
+
## See also
|
| 304 |
+
- [Mapper Classifier](./classifier/README.md)
|
| 305 |
+
- [RComplexion](./rcomplexion/README.md)
|
| 306 |
+
|
| 307 |
+
## Credits
|
| 308 |
+
|
| 309 |
+
Special thanks to:
|
| 310 |
+
1. The authors of [osuT5](https://github.com/gyataro/osuT5) for their training code.
|
| 311 |
+
2. Hugging Face team for their [tools](https://huggingface.co/docs/transformers/index).
|
| 312 |
+
3. [Jason Won](https://github.com/jaswon) and [Richard Nagyfi](https://github.com/sedthh) for bouncing ideas.
|
| 313 |
+
4. [Marvin](https://github.com/minetoblend) for donating training credits.
|
| 314 |
+
5. The osu! community for the beatmaps.
|
| 315 |
+
|
| 316 |
+
## Related works
|
| 317 |
+
|
| 318 |
+
1. [osu! Beatmap Generator](https://github.com/Syps/osu_beatmap_generator) by Syps (Nick Sypteras)
|
| 319 |
+
2. [osumapper](https://github.com/kotritrona/osumapper) by kotritrona, jyvden, Yoyolick (Ryan Zmuda)
|
| 320 |
+
3. [osu-diffusion](https://github.com/OliBomby/osu-diffusion) by OliBomby (Olivier Schipper), NiceAesth (Andrei Baciu)
|
| 321 |
+
4. [osuT5](https://github.com/gyataro/osuT5) by gyataro (Xiwen Teoh)
|
| 322 |
+
5. [Beat Learning](https://github.com/sedthh/BeatLearning) by sedthh (Richard Nagyfi)
|
| 323 |
+
6. [osu!dreamer](https://github.com/jaswon/osu-dreamer) by jaswon (Jason Won)
|
audit_all_configs.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Comprehensive Config Audit Script for BeatHeritage
|
| 4 |
+
Checks all config files against their corresponding dataclass definitions
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import yaml
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from dataclasses import fields
|
| 11 |
+
from typing import Dict, List, Set, Any
|
| 12 |
+
|
| 13 |
+
# Import all config classes
|
| 14 |
+
from config import InferenceConfig, FidConfig, MaiModConfig
|
| 15 |
+
from osuT5.osuT5.config import TrainConfig, DataConfig, DataloaderConfig, OptimizerConfig
|
| 16 |
+
from osu_diffusion.config import DiffusionTrainConfig
|
| 17 |
+
|
| 18 |
+
def get_config_fields(config_class) -> Set[str]:
|
| 19 |
+
"""Get all field names from a dataclass"""
|
| 20 |
+
return {field.name for field in fields(config_class)}
|
| 21 |
+
|
| 22 |
+
def get_yaml_keys(yaml_path: str, prefix: str = "") -> Set[str]:
|
| 23 |
+
"""Get all keys from a YAML file, including nested keys with dot notation"""
|
| 24 |
+
keys = set()
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
with open(yaml_path, 'r') as f:
|
| 28 |
+
data = yaml.safe_load(f)
|
| 29 |
+
|
| 30 |
+
def extract_keys(obj, parent_key=""):
|
| 31 |
+
if isinstance(obj, dict):
|
| 32 |
+
for key, value in obj.items():
|
| 33 |
+
if key == 'defaults': # Skip Hydra defaults
|
| 34 |
+
continue
|
| 35 |
+
|
| 36 |
+
full_key = f"{parent_key}.{key}" if parent_key else key
|
| 37 |
+
keys.add(full_key)
|
| 38 |
+
|
| 39 |
+
if isinstance(value, dict):
|
| 40 |
+
extract_keys(value, full_key)
|
| 41 |
+
elif isinstance(value, list) and value and isinstance(value[0], dict):
|
| 42 |
+
# Handle list of dicts
|
| 43 |
+
extract_keys(value[0], full_key)
|
| 44 |
+
|
| 45 |
+
extract_keys(data)
|
| 46 |
+
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"Error reading {yaml_path}: {e}")
|
| 49 |
+
|
| 50 |
+
return keys
|
| 51 |
+
|
| 52 |
+
def audit_config_mapping(config_path: str, config_class, config_name: str):
|
| 53 |
+
"""Audit a specific config file against its dataclass"""
|
| 54 |
+
print(f"\n[AUDIT] {config_name}: {config_path}")
|
| 55 |
+
|
| 56 |
+
if not os.path.exists(config_path):
|
| 57 |
+
print(f"[ERROR] Config file not found: {config_path}")
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
# Get fields from dataclass
|
| 61 |
+
class_fields = get_config_fields(config_class)
|
| 62 |
+
|
| 63 |
+
# Get keys from YAML
|
| 64 |
+
yaml_keys = get_yaml_keys(config_path)
|
| 65 |
+
|
| 66 |
+
# Find mismatches
|
| 67 |
+
missing_in_class = yaml_keys - class_fields
|
| 68 |
+
missing_in_config = class_fields - yaml_keys
|
| 69 |
+
|
| 70 |
+
# Filter out nested keys for top-level check
|
| 71 |
+
top_level_yaml = {key.split('.')[0] for key in yaml_keys}
|
| 72 |
+
top_level_missing = top_level_yaml - class_fields
|
| 73 |
+
|
| 74 |
+
print(f"[SUMMARY]:")
|
| 75 |
+
print(f" - Dataclass fields: {len(class_fields)}")
|
| 76 |
+
print(f" - YAML keys (all): {len(yaml_keys)}")
|
| 77 |
+
print(f" - YAML keys (top-level): {len(top_level_yaml)}")
|
| 78 |
+
|
| 79 |
+
if top_level_missing:
|
| 80 |
+
print(f"[MISSING] Keys in YAML but missing in dataclass:")
|
| 81 |
+
for key in sorted(top_level_missing):
|
| 82 |
+
related_keys = [k for k in yaml_keys if k.startswith(key)]
|
| 83 |
+
print(f" - {key} (related: {len(related_keys)} keys)")
|
| 84 |
+
if len(related_keys) <= 5: # Show details for small sections
|
| 85 |
+
for rkey in sorted(related_keys)[:5]:
|
| 86 |
+
print(f" * {rkey}")
|
| 87 |
+
else:
|
| 88 |
+
print(f" * ... and {len(related_keys)-3} more keys")
|
| 89 |
+
|
| 90 |
+
if missing_in_config:
|
| 91 |
+
optional_missing = missing_in_config & {'hydra', 'train', 'diffusion'} # Usually optional
|
| 92 |
+
real_missing = missing_in_config - optional_missing
|
| 93 |
+
if real_missing:
|
| 94 |
+
print(f"[WARNING] Fields in dataclass but missing in YAML:")
|
| 95 |
+
for key in sorted(real_missing):
|
| 96 |
+
print(f" - {key}")
|
| 97 |
+
|
| 98 |
+
return {
|
| 99 |
+
'missing_in_class': top_level_missing,
|
| 100 |
+
'missing_in_config': missing_in_config,
|
| 101 |
+
'all_yaml_keys': yaml_keys,
|
| 102 |
+
'class_fields': class_fields
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
def main():
|
| 106 |
+
"""Run comprehensive config audit"""
|
| 107 |
+
print("BeatHeritage Config Audit - Finding ALL Mismatches")
|
| 108 |
+
print("=" * 60)
|
| 109 |
+
|
| 110 |
+
# Define config mappings
|
| 111 |
+
config_mappings = [
|
| 112 |
+
# Inference configs
|
| 113 |
+
("configs/inference/beatheritage_v1.yaml", InferenceConfig, "Inference (BeatHeritage V1)"),
|
| 114 |
+
("configs/inference/default.yaml", InferenceConfig, "Inference (Default)"),
|
| 115 |
+
|
| 116 |
+
# Training configs
|
| 117 |
+
("configs/train/beatheritage_v1.yaml", TrainConfig, "Training (BeatHeritage V1)"),
|
| 118 |
+
("configs/train/default.yaml", TrainConfig, "Training (Default)"),
|
| 119 |
+
|
| 120 |
+
# Diffusion configs
|
| 121 |
+
("configs/diffusion/v1.yaml", DiffusionTrainConfig, "Diffusion (V1)"),
|
| 122 |
+
]
|
| 123 |
+
|
| 124 |
+
all_issues = {}
|
| 125 |
+
|
| 126 |
+
for config_path, config_class, name in config_mappings:
|
| 127 |
+
issues = audit_config_mapping(config_path, config_class, name)
|
| 128 |
+
if issues and issues['missing_in_class']:
|
| 129 |
+
all_issues[name] = issues
|
| 130 |
+
|
| 131 |
+
# Summary report
|
| 132 |
+
print(f"\nAUDIT SUMMARY")
|
| 133 |
+
print("=" * 60)
|
| 134 |
+
|
| 135 |
+
if not all_issues:
|
| 136 |
+
print("All configs are aligned with their dataclasses!")
|
| 137 |
+
return
|
| 138 |
+
|
| 139 |
+
print(f"Found issues in {len(all_issues)} config(s):")
|
| 140 |
+
|
| 141 |
+
for config_name, issues in all_issues.items():
|
| 142 |
+
print(f"\n{config_name}:")
|
| 143 |
+
for key in sorted(issues['missing_in_class']):
|
| 144 |
+
print(f" - Missing field: {key}")
|
| 145 |
+
|
| 146 |
+
# Generate fix suggestions
|
| 147 |
+
print(f"\nSUGGESTED FIXES")
|
| 148 |
+
print("=" * 60)
|
| 149 |
+
|
| 150 |
+
for config_name, issues in all_issues.items():
|
| 151 |
+
if 'Inference' in config_name:
|
| 152 |
+
print(f"\nFor InferenceConfig class:")
|
| 153 |
+
for key in sorted(issues['missing_in_class']):
|
| 154 |
+
print(f" + {key}: <appropriate_type> = <default_value>")
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
main()
|
beatheritage_postprocessor.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BeatHeritage V1 Custom Postprocessor
|
| 3 |
+
Enhanced postprocessing for improved beatmap quality
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Tuple, Dict, Optional
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from osuT5.osuT5.inference.postprocessor import Postprocessor, BeatmapConfig
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class BeatHeritageConfig(BeatmapConfig):
|
| 18 |
+
"""Enhanced config for BeatHeritage V1 postprocessing"""
|
| 19 |
+
# Quality control parameters
|
| 20 |
+
min_distance_threshold: float = 20.0
|
| 21 |
+
max_overlap_ratio: float = 0.15
|
| 22 |
+
enable_auto_correction: bool = True
|
| 23 |
+
enable_flow_optimization: bool = True
|
| 24 |
+
|
| 25 |
+
# Pattern enhancement
|
| 26 |
+
enable_pattern_variety: bool = True
|
| 27 |
+
pattern_complexity_target: float = 0.7
|
| 28 |
+
|
| 29 |
+
# Difficulty scaling
|
| 30 |
+
enable_difficulty_scaling: bool = True
|
| 31 |
+
difficulty_variance_threshold: float = 0.3
|
| 32 |
+
|
| 33 |
+
# Style preservation
|
| 34 |
+
enable_style_preservation: bool = True
|
| 35 |
+
style_consistency_weight: float = 0.8
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BeatHeritagePostprocessor(Postprocessor):
|
| 39 |
+
"""Enhanced postprocessor for BeatHeritage V1"""
|
| 40 |
+
|
| 41 |
+
def __init__(self, config: BeatHeritageConfig):
|
| 42 |
+
super().__init__(config)
|
| 43 |
+
self.config = config
|
| 44 |
+
self.flow_optimizer = FlowOptimizer(config)
|
| 45 |
+
self.pattern_enhancer = PatternEnhancer(config)
|
| 46 |
+
self.quality_controller = QualityController(config)
|
| 47 |
+
|
| 48 |
+
def postprocess(self, beatmap_data: Dict) -> Dict:
|
| 49 |
+
"""
|
| 50 |
+
Enhanced postprocessing pipeline for BeatHeritage V1
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
beatmap_data: Raw beatmap data from model
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
Processed beatmap data with enhancements
|
| 57 |
+
"""
|
| 58 |
+
# Base postprocessing
|
| 59 |
+
beatmap_data = super().postprocess(beatmap_data)
|
| 60 |
+
|
| 61 |
+
# Quality control
|
| 62 |
+
if self.config.enable_auto_correction:
|
| 63 |
+
beatmap_data = self.quality_controller.fix_spacing_issues(beatmap_data)
|
| 64 |
+
beatmap_data = self.quality_controller.fix_overlaps(beatmap_data)
|
| 65 |
+
|
| 66 |
+
# Flow optimization
|
| 67 |
+
if self.config.enable_flow_optimization:
|
| 68 |
+
beatmap_data = self.flow_optimizer.optimize_flow(beatmap_data)
|
| 69 |
+
|
| 70 |
+
# Pattern enhancement
|
| 71 |
+
if self.config.enable_pattern_variety:
|
| 72 |
+
beatmap_data = self.pattern_enhancer.enhance_patterns(beatmap_data)
|
| 73 |
+
|
| 74 |
+
# Difficulty scaling
|
| 75 |
+
if self.config.enable_difficulty_scaling:
|
| 76 |
+
beatmap_data = self._scale_difficulty(beatmap_data)
|
| 77 |
+
|
| 78 |
+
# Style preservation
|
| 79 |
+
if self.config.enable_style_preservation:
|
| 80 |
+
beatmap_data = self._preserve_style(beatmap_data)
|
| 81 |
+
|
| 82 |
+
return beatmap_data
|
| 83 |
+
|
| 84 |
+
def _scale_difficulty(self, beatmap_data: Dict) -> Dict:
|
| 85 |
+
"""Scale difficulty to match target star rating"""
|
| 86 |
+
target_difficulty = self.config.difficulty
|
| 87 |
+
if target_difficulty is None:
|
| 88 |
+
return beatmap_data
|
| 89 |
+
|
| 90 |
+
current_difficulty = self._calculate_difficulty(beatmap_data)
|
| 91 |
+
scale_factor = target_difficulty / max(current_difficulty, 0.1)
|
| 92 |
+
|
| 93 |
+
# Adjust spacing and timing based on scale factor
|
| 94 |
+
if 'hit_objects' in beatmap_data:
|
| 95 |
+
for obj in beatmap_data['hit_objects']:
|
| 96 |
+
if 'distance' in obj:
|
| 97 |
+
obj['distance'] *= scale_factor
|
| 98 |
+
|
| 99 |
+
logger.info(f"Scaled difficulty from {current_difficulty:.2f} to {target_difficulty:.2f}")
|
| 100 |
+
return beatmap_data
|
| 101 |
+
|
| 102 |
+
def _preserve_style(self, beatmap_data: Dict) -> Dict:
|
| 103 |
+
"""Preserve mapping style consistency"""
|
| 104 |
+
# Analyze style characteristics
|
| 105 |
+
style_features = self._extract_style_features(beatmap_data)
|
| 106 |
+
|
| 107 |
+
# Apply style consistency
|
| 108 |
+
consistency_weight = self.config.style_consistency_weight
|
| 109 |
+
|
| 110 |
+
if 'hit_objects' in beatmap_data:
|
| 111 |
+
for i, obj in enumerate(beatmap_data['hit_objects']):
|
| 112 |
+
if i > 0:
|
| 113 |
+
# Maintain consistent spacing patterns
|
| 114 |
+
prev_obj = beatmap_data['hit_objects'][i-1]
|
| 115 |
+
expected_distance = style_features.get('avg_distance', 100)
|
| 116 |
+
|
| 117 |
+
if 'position' in obj and 'position' in prev_obj:
|
| 118 |
+
current_distance = self._calculate_distance(
|
| 119 |
+
obj['position'], prev_obj['position']
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Blend current with expected based on consistency weight
|
| 123 |
+
adjusted_distance = (
|
| 124 |
+
current_distance * (1 - consistency_weight) +
|
| 125 |
+
expected_distance * consistency_weight
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Adjust position to match distance
|
| 129 |
+
obj['position'] = self._adjust_position(
|
| 130 |
+
prev_obj['position'],
|
| 131 |
+
obj['position'],
|
| 132 |
+
adjusted_distance
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
return beatmap_data
|
| 136 |
+
|
| 137 |
+
def _calculate_difficulty(self, beatmap_data: Dict) -> float:
|
| 138 |
+
"""Calculate approximate star rating"""
|
| 139 |
+
# Simplified difficulty calculation
|
| 140 |
+
num_objects = len(beatmap_data.get('hit_objects', []))
|
| 141 |
+
avg_spacing = self._calculate_avg_spacing(beatmap_data)
|
| 142 |
+
bpm = beatmap_data.get('bpm', 180)
|
| 143 |
+
|
| 144 |
+
# Simple formula (can be improved)
|
| 145 |
+
difficulty = (num_objects / 100) * (avg_spacing / 50) * (bpm / 180)
|
| 146 |
+
return min(max(difficulty, 0), 10) # Clamp to 0-10
|
| 147 |
+
|
| 148 |
+
def _extract_style_features(self, beatmap_data: Dict) -> Dict:
|
| 149 |
+
"""Extract style characteristics from beatmap"""
|
| 150 |
+
features = {}
|
| 151 |
+
|
| 152 |
+
if 'hit_objects' in beatmap_data:
|
| 153 |
+
distances = []
|
| 154 |
+
for i in range(1, len(beatmap_data['hit_objects'])):
|
| 155 |
+
if 'position' in beatmap_data['hit_objects'][i]:
|
| 156 |
+
dist = self._calculate_distance(
|
| 157 |
+
beatmap_data['hit_objects'][i-1].get('position', (256, 192)),
|
| 158 |
+
beatmap_data['hit_objects'][i]['position']
|
| 159 |
+
)
|
| 160 |
+
distances.append(dist)
|
| 161 |
+
|
| 162 |
+
if distances:
|
| 163 |
+
features['avg_distance'] = np.mean(distances)
|
| 164 |
+
features['distance_variance'] = np.var(distances)
|
| 165 |
+
|
| 166 |
+
return features
|
| 167 |
+
|
| 168 |
+
def _calculate_avg_spacing(self, beatmap_data: Dict) -> float:
|
| 169 |
+
"""Calculate average spacing between objects"""
|
| 170 |
+
distances = []
|
| 171 |
+
objects = beatmap_data.get('hit_objects', [])
|
| 172 |
+
|
| 173 |
+
for i in range(1, len(objects)):
|
| 174 |
+
if 'position' in objects[i] and 'position' in objects[i-1]:
|
| 175 |
+
dist = self._calculate_distance(
|
| 176 |
+
objects[i-1]['position'],
|
| 177 |
+
objects[i]['position']
|
| 178 |
+
)
|
| 179 |
+
distances.append(dist)
|
| 180 |
+
|
| 181 |
+
return np.mean(distances) if distances else 100
|
| 182 |
+
|
| 183 |
+
def _calculate_distance(self, pos1: Tuple[float, float],
|
| 184 |
+
pos2: Tuple[float, float]) -> float:
|
| 185 |
+
"""Calculate Euclidean distance between two positions"""
|
| 186 |
+
return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)
|
| 187 |
+
|
| 188 |
+
def _adjust_position(self, from_pos: Tuple[float, float],
|
| 189 |
+
to_pos: Tuple[float, float],
|
| 190 |
+
target_distance: float) -> Tuple[float, float]:
|
| 191 |
+
"""Adjust position to achieve target distance"""
|
| 192 |
+
current_distance = self._calculate_distance(from_pos, to_pos)
|
| 193 |
+
if current_distance < 0.01: # Avoid division by zero
|
| 194 |
+
return to_pos
|
| 195 |
+
|
| 196 |
+
scale = target_distance / current_distance
|
| 197 |
+
dx = (to_pos[0] - from_pos[0]) * scale
|
| 198 |
+
dy = (to_pos[1] - from_pos[1]) * scale
|
| 199 |
+
|
| 200 |
+
# Keep within playfield bounds
|
| 201 |
+
new_x = max(0, min(512, from_pos[0] + dx))
|
| 202 |
+
new_y = max(0, min(384, from_pos[1] + dy))
|
| 203 |
+
|
| 204 |
+
return (new_x, new_y)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class FlowOptimizer:
|
| 208 |
+
"""Optimize flow patterns in beatmaps"""
|
| 209 |
+
|
| 210 |
+
def __init__(self, config: BeatHeritageConfig):
|
| 211 |
+
self.config = config
|
| 212 |
+
|
| 213 |
+
def optimize_flow(self, beatmap_data: Dict) -> Dict:
|
| 214 |
+
"""Optimize flow for better playability"""
|
| 215 |
+
if 'hit_objects' not in beatmap_data:
|
| 216 |
+
return beatmap_data
|
| 217 |
+
|
| 218 |
+
objects = beatmap_data['hit_objects']
|
| 219 |
+
optimized_objects = []
|
| 220 |
+
|
| 221 |
+
for i, obj in enumerate(objects):
|
| 222 |
+
if i >= 2 and 'position' in obj:
|
| 223 |
+
# Calculate flow angle
|
| 224 |
+
prev_angle = self._calculate_angle(
|
| 225 |
+
objects[i-2].get('position', (256, 192)),
|
| 226 |
+
objects[i-1].get('position', (256, 192))
|
| 227 |
+
)
|
| 228 |
+
current_angle = self._calculate_angle(
|
| 229 |
+
objects[i-1].get('position', (256, 192)),
|
| 230 |
+
obj['position']
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Smooth sharp angles
|
| 234 |
+
angle_diff = abs(current_angle - prev_angle)
|
| 235 |
+
if angle_diff > 120: # Sharp angle threshold
|
| 236 |
+
# Adjust position for smoother flow
|
| 237 |
+
smoothed_angle = prev_angle + np.sign(current_angle - prev_angle) * 90
|
| 238 |
+
distance = self._calculate_distance(
|
| 239 |
+
objects[i-1]['position'],
|
| 240 |
+
obj['position']
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
new_x = objects[i-1]['position'][0] + distance * np.cos(np.radians(smoothed_angle))
|
| 244 |
+
new_y = objects[i-1]['position'][1] + distance * np.sin(np.radians(smoothed_angle))
|
| 245 |
+
|
| 246 |
+
obj['position'] = (
|
| 247 |
+
max(0, min(512, new_x)),
|
| 248 |
+
max(0, min(384, new_y))
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
optimized_objects.append(obj)
|
| 252 |
+
|
| 253 |
+
beatmap_data['hit_objects'] = optimized_objects
|
| 254 |
+
return beatmap_data
|
| 255 |
+
|
| 256 |
+
def _calculate_angle(self, pos1: Tuple[float, float],
|
| 257 |
+
pos2: Tuple[float, float]) -> float:
|
| 258 |
+
"""Calculate angle between two positions in degrees"""
|
| 259 |
+
return np.degrees(np.arctan2(pos2[1] - pos1[1], pos2[0] - pos1[0]))
|
| 260 |
+
|
| 261 |
+
def _calculate_distance(self, pos1: Tuple[float, float],
|
| 262 |
+
pos2: Tuple[float, float]) -> float:
|
| 263 |
+
"""Calculate Euclidean distance"""
|
| 264 |
+
return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class PatternEnhancer:
|
| 268 |
+
"""Enhance pattern variety in beatmaps"""
|
| 269 |
+
|
| 270 |
+
def __init__(self, config: BeatHeritageConfig):
|
| 271 |
+
self.config = config
|
| 272 |
+
self.pattern_library = self._load_pattern_library()
|
| 273 |
+
|
| 274 |
+
def enhance_patterns(self, beatmap_data: Dict) -> Dict:
|
| 275 |
+
"""Enhance patterns for more variety"""
|
| 276 |
+
if 'hit_objects' not in beatmap_data:
|
| 277 |
+
return beatmap_data
|
| 278 |
+
|
| 279 |
+
# Detect repetitive patterns
|
| 280 |
+
repetitive_sections = self._detect_repetitive_patterns(beatmap_data)
|
| 281 |
+
|
| 282 |
+
# Replace with varied patterns
|
| 283 |
+
for section in repetitive_sections:
|
| 284 |
+
beatmap_data = self._vary_pattern(beatmap_data, section)
|
| 285 |
+
|
| 286 |
+
return beatmap_data
|
| 287 |
+
|
| 288 |
+
def _load_pattern_library(self) -> List[Dict]:
|
| 289 |
+
"""Load common mapping patterns"""
|
| 290 |
+
return [
|
| 291 |
+
{'name': 'triangle', 'positions': [(0, 0), (100, 0), (50, 86.6)]},
|
| 292 |
+
{'name': 'square', 'positions': [(0, 0), (100, 0), (100, 100), (0, 100)]},
|
| 293 |
+
{'name': 'star', 'positions': [(50, 0), (61, 35), (97, 35), (68, 57), (79, 91), (50, 70), (21, 91), (32, 57), (3, 35), (39, 35)]},
|
| 294 |
+
{'name': 'hexagon', 'positions': [(50, 0), (93, 25), (93, 75), (50, 100), (7, 75), (7, 25)]},
|
| 295 |
+
]
|
| 296 |
+
|
| 297 |
+
def _detect_repetitive_patterns(self, beatmap_data: Dict) -> List[Tuple[int, int]]:
|
| 298 |
+
"""Detect sections with repetitive patterns"""
|
| 299 |
+
repetitive_sections = []
|
| 300 |
+
objects = beatmap_data.get('hit_objects', [])
|
| 301 |
+
|
| 302 |
+
window_size = 8
|
| 303 |
+
for i in range(len(objects) - window_size * 2):
|
| 304 |
+
pattern1 = self._extract_pattern(objects[i:i+window_size])
|
| 305 |
+
pattern2 = self._extract_pattern(objects[i+window_size:i+window_size*2])
|
| 306 |
+
|
| 307 |
+
if self._patterns_similar(pattern1, pattern2):
|
| 308 |
+
repetitive_sections.append((i, i + window_size * 2))
|
| 309 |
+
|
| 310 |
+
return repetitive_sections
|
| 311 |
+
|
| 312 |
+
def _extract_pattern(self, objects: List[Dict]) -> List[Tuple[float, float]]:
|
| 313 |
+
"""Extract position pattern from objects"""
|
| 314 |
+
return [obj.get('position', (256, 192)) for obj in objects]
|
| 315 |
+
|
| 316 |
+
def _patterns_similar(self, pattern1: List, pattern2: List, threshold: float = 0.8) -> bool:
|
| 317 |
+
"""Check if two patterns are similar"""
|
| 318 |
+
if len(pattern1) != len(pattern2):
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
distances = []
|
| 322 |
+
for pos1, pos2 in zip(pattern1, pattern2):
|
| 323 |
+
dist = np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)
|
| 324 |
+
distances.append(dist)
|
| 325 |
+
|
| 326 |
+
avg_distance = np.mean(distances)
|
| 327 |
+
return avg_distance < 50 # Threshold for similarity
|
| 328 |
+
|
| 329 |
+
def _vary_pattern(self, beatmap_data: Dict, section: Tuple[int, int]) -> Dict:
|
| 330 |
+
"""Apply variation to a pattern section"""
|
| 331 |
+
start, end = section
|
| 332 |
+
objects = beatmap_data['hit_objects']
|
| 333 |
+
|
| 334 |
+
# Select a random pattern from library
|
| 335 |
+
pattern = np.random.choice(self.pattern_library)
|
| 336 |
+
pattern_positions = pattern['positions']
|
| 337 |
+
|
| 338 |
+
# Apply pattern with scaling
|
| 339 |
+
section_length = end - start
|
| 340 |
+
for i in range(start, min(end, len(objects))):
|
| 341 |
+
if 'position' in objects[i]:
|
| 342 |
+
pattern_idx = (i - start) % len(pattern_positions)
|
| 343 |
+
base_pos = pattern_positions[pattern_idx]
|
| 344 |
+
|
| 345 |
+
# Scale and translate pattern
|
| 346 |
+
center = (256, 192)
|
| 347 |
+
scale = 2.0
|
| 348 |
+
|
| 349 |
+
new_x = center[0] + base_pos[0] * scale
|
| 350 |
+
new_y = center[1] + base_pos[1] * scale
|
| 351 |
+
|
| 352 |
+
objects[i]['position'] = (
|
| 353 |
+
max(0, min(512, new_x)),
|
| 354 |
+
max(0, min(384, new_y))
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
return beatmap_data
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class QualityController:
|
| 361 |
+
"""Control quality aspects of beatmaps"""
|
| 362 |
+
|
| 363 |
+
def __init__(self, config: BeatHeritageConfig):
|
| 364 |
+
self.config = config
|
| 365 |
+
|
| 366 |
+
def fix_spacing_issues(self, beatmap_data: Dict) -> Dict:
|
| 367 |
+
"""Fix objects that are too close together"""
|
| 368 |
+
if 'hit_objects' not in beatmap_data:
|
| 369 |
+
return beatmap_data
|
| 370 |
+
|
| 371 |
+
objects = beatmap_data['hit_objects']
|
| 372 |
+
min_distance = self.config.min_distance_threshold
|
| 373 |
+
|
| 374 |
+
for i in range(1, len(objects)):
|
| 375 |
+
if 'position' in objects[i] and 'position' in objects[i-1]:
|
| 376 |
+
distance = self._calculate_distance(
|
| 377 |
+
objects[i-1]['position'],
|
| 378 |
+
objects[i]['position']
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
if distance < min_distance:
|
| 382 |
+
# Move object to maintain minimum distance
|
| 383 |
+
direction = self._get_direction(
|
| 384 |
+
objects[i-1]['position'],
|
| 385 |
+
objects[i]['position']
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
objects[i]['position'] = self._move_position(
|
| 389 |
+
objects[i-1]['position'],
|
| 390 |
+
direction,
|
| 391 |
+
min_distance
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
return beatmap_data
|
| 395 |
+
|
| 396 |
+
def fix_overlaps(self, beatmap_data: Dict) -> Dict:
|
| 397 |
+
"""Fix overlapping sliders and circles"""
|
| 398 |
+
if 'hit_objects' not in beatmap_data:
|
| 399 |
+
return beatmap_data
|
| 400 |
+
|
| 401 |
+
objects = beatmap_data['hit_objects']
|
| 402 |
+
max_overlap = self.config.max_overlap_ratio
|
| 403 |
+
|
| 404 |
+
for i in range(len(objects)):
|
| 405 |
+
for j in range(i+1, min(i+10, len(objects))): # Check next 10 objects
|
| 406 |
+
if self._objects_overlap(objects[i], objects[j], max_overlap):
|
| 407 |
+
# Adjust position to reduce overlap
|
| 408 |
+
objects[j] = self._adjust_for_overlap(objects[i], objects[j])
|
| 409 |
+
|
| 410 |
+
return beatmap_data
|
| 411 |
+
|
| 412 |
+
def _calculate_distance(self, pos1: Tuple[float, float],
|
| 413 |
+
pos2: Tuple[float, float]) -> float:
|
| 414 |
+
"""Calculate Euclidean distance"""
|
| 415 |
+
return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)
|
| 416 |
+
|
| 417 |
+
def _get_direction(self, from_pos: Tuple[float, float],
|
| 418 |
+
to_pos: Tuple[float, float]) -> Tuple[float, float]:
|
| 419 |
+
"""Get normalized direction vector"""
|
| 420 |
+
dx = to_pos[0] - from_pos[0]
|
| 421 |
+
dy = to_pos[1] - from_pos[1]
|
| 422 |
+
|
| 423 |
+
length = np.sqrt(dx**2 + dy**2)
|
| 424 |
+
if length < 0.01:
|
| 425 |
+
return (1, 0) # Default right direction
|
| 426 |
+
|
| 427 |
+
return (dx / length, dy / length)
|
| 428 |
+
|
| 429 |
+
def _move_position(self, from_pos: Tuple[float, float],
|
| 430 |
+
direction: Tuple[float, float],
|
| 431 |
+
distance: float) -> Tuple[float, float]:
|
| 432 |
+
"""Move position in direction by distance"""
|
| 433 |
+
new_x = from_pos[0] + direction[0] * distance
|
| 434 |
+
new_y = from_pos[1] + direction[1] * distance
|
| 435 |
+
|
| 436 |
+
# Keep within bounds
|
| 437 |
+
return (
|
| 438 |
+
max(0, min(512, new_x)),
|
| 439 |
+
max(0, min(384, new_y))
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
def _objects_overlap(self, obj1: Dict, obj2: Dict, threshold: float) -> bool:
|
| 443 |
+
"""Check if two objects overlap beyond threshold"""
|
| 444 |
+
if 'position' not in obj1 or 'position' not in obj2:
|
| 445 |
+
return False
|
| 446 |
+
|
| 447 |
+
distance = self._calculate_distance(obj1['position'], obj2['position'])
|
| 448 |
+
|
| 449 |
+
# Simple overlap check (can be improved for sliders)
|
| 450 |
+
radius = 30 # Approximate circle radius
|
| 451 |
+
overlap = max(0, 2 * radius - distance) / (2 * radius)
|
| 452 |
+
|
| 453 |
+
return overlap > threshold
|
| 454 |
+
|
| 455 |
+
def _adjust_for_overlap(self, obj1: Dict, obj2: Dict) -> Dict:
|
| 456 |
+
"""Adjust object position to reduce overlap"""
|
| 457 |
+
if 'position' not in obj1 or 'position' not in obj2:
|
| 458 |
+
return obj2
|
| 459 |
+
|
| 460 |
+
# Move obj2 away from obj1
|
| 461 |
+
direction = self._get_direction(obj1['position'], obj2['position'])
|
| 462 |
+
min_safe_distance = 60 # Minimum safe distance
|
| 463 |
+
|
| 464 |
+
obj2['position'] = self._move_position(
|
| 465 |
+
obj1['position'],
|
| 466 |
+
direction,
|
| 467 |
+
min_safe_distance
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
return obj2
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Export main postprocessor
|
| 474 |
+
__all__ = ['BeatHeritagePostprocessor', 'BeatHeritageConfig']
|
benchmark_comparison.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
BeatHeritage V1 vs Mapperatorinator V30 Benchmark Script
|
| 4 |
+
Compares performance, quality, and generation characteristics
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import json
|
| 11 |
+
import argparse
|
| 12 |
+
import subprocess
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Tuple, Optional
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
import torch
|
| 19 |
+
import matplotlib.pyplot as plt
|
| 20 |
+
import seaborn as sns
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
# Setup logging
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=logging.INFO,
|
| 27 |
+
format='%(asctime)s - %(levelname)s - %(message)s'
|
| 28 |
+
)
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class BenchmarkRunner:
|
| 33 |
+
"""Run benchmarks comparing BeatHeritage V1 with Mapperatorinator V30"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, output_dir: str = "./benchmark_results"):
|
| 36 |
+
self.output_dir = Path(output_dir)
|
| 37 |
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 38 |
+
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 39 |
+
self.results = []
|
| 40 |
+
|
| 41 |
+
def run_inference(self, model_config: str, audio_path: str,
|
| 42 |
+
gamemode: int, difficulty: float) -> Dict:
|
| 43 |
+
"""Run inference with specified model and parameters"""
|
| 44 |
+
|
| 45 |
+
output_path = self.output_dir / f"{model_config}_{Path(audio_path).stem}"
|
| 46 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
|
| 48 |
+
cmd = [
|
| 49 |
+
'python', 'inference.py',
|
| 50 |
+
'-cn', model_config,
|
| 51 |
+
f'audio_path={audio_path}',
|
| 52 |
+
f'output_path={str(output_path)}',
|
| 53 |
+
f'gamemode={gamemode}',
|
| 54 |
+
f'difficulty={difficulty}',
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
# Add model-specific parameters
|
| 58 |
+
if model_config == 'beatheritage_v1':
|
| 59 |
+
cmd.extend([
|
| 60 |
+
'temperature=0.85',
|
| 61 |
+
'top_p=0.92',
|
| 62 |
+
'quality_control.enable_auto_correction=true',
|
| 63 |
+
'quality_control.enable_flow_optimization=true',
|
| 64 |
+
'advanced_features.enable_pattern_variety=true',
|
| 65 |
+
])
|
| 66 |
+
else: # v30
|
| 67 |
+
cmd.extend([
|
| 68 |
+
'temperature=0.9',
|
| 69 |
+
'top_p=0.9',
|
| 70 |
+
])
|
| 71 |
+
|
| 72 |
+
# Measure performance
|
| 73 |
+
start_time = time.time()
|
| 74 |
+
memory_before = self._get_memory_usage()
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
result = subprocess.run(
|
| 78 |
+
cmd,
|
| 79 |
+
capture_output=True,
|
| 80 |
+
text=True,
|
| 81 |
+
check=True
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
end_time = time.time()
|
| 85 |
+
memory_after = self._get_memory_usage()
|
| 86 |
+
|
| 87 |
+
# Parse output for quality metrics
|
| 88 |
+
output_files = list(output_path.glob('*.osu'))
|
| 89 |
+
|
| 90 |
+
metrics = {
|
| 91 |
+
'model': model_config,
|
| 92 |
+
'audio': Path(audio_path).name,
|
| 93 |
+
'gamemode': gamemode,
|
| 94 |
+
'difficulty': difficulty,
|
| 95 |
+
'generation_time': end_time - start_time,
|
| 96 |
+
'memory_usage': memory_after - memory_before,
|
| 97 |
+
'success': True,
|
| 98 |
+
'output_files': len(output_files),
|
| 99 |
+
'quality_metrics': self._analyze_quality(output_files[0] if output_files else None)
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
except subprocess.CalledProcessError as e:
|
| 103 |
+
logger.error(f"Error running {model_config}: {e}")
|
| 104 |
+
metrics = {
|
| 105 |
+
'model': model_config,
|
| 106 |
+
'audio': Path(audio_path).name,
|
| 107 |
+
'gamemode': gamemode,
|
| 108 |
+
'difficulty': difficulty,
|
| 109 |
+
'generation_time': -1,
|
| 110 |
+
'memory_usage': -1,
|
| 111 |
+
'success': False,
|
| 112 |
+
'error': str(e),
|
| 113 |
+
'output_files': 0,
|
| 114 |
+
'quality_metrics': {}
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
return metrics
|
| 118 |
+
|
| 119 |
+
def _get_memory_usage(self) -> float:
|
| 120 |
+
"""Get current GPU memory usage in MB"""
|
| 121 |
+
if torch.cuda.is_available():
|
| 122 |
+
return torch.cuda.memory_allocated() / 1024**2
|
| 123 |
+
return 0
|
| 124 |
+
|
| 125 |
+
def _analyze_quality(self, osu_file: Optional[Path]) -> Dict:
|
| 126 |
+
"""Analyze quality metrics of generated beatmap"""
|
| 127 |
+
if not osu_file or not osu_file.exists():
|
| 128 |
+
return {}
|
| 129 |
+
|
| 130 |
+
metrics = {
|
| 131 |
+
'object_count': 0,
|
| 132 |
+
'avg_spacing': 0,
|
| 133 |
+
'spacing_variance': 0,
|
| 134 |
+
'pattern_diversity': 0,
|
| 135 |
+
'flow_score': 0,
|
| 136 |
+
'difficulty_consistency': 0
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
with open(osu_file, 'r', encoding='utf-8') as f:
|
| 141 |
+
lines = f.readlines()
|
| 142 |
+
|
| 143 |
+
# Parse hit objects
|
| 144 |
+
hit_objects = []
|
| 145 |
+
in_hit_objects = False
|
| 146 |
+
|
| 147 |
+
for line in lines:
|
| 148 |
+
if '[HitObjects]' in line:
|
| 149 |
+
in_hit_objects = True
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
if in_hit_objects and line.strip():
|
| 153 |
+
parts = line.strip().split(',')
|
| 154 |
+
if len(parts) >= 2:
|
| 155 |
+
try:
|
| 156 |
+
x, y = int(parts[0]), int(parts[1])
|
| 157 |
+
hit_objects.append((x, y))
|
| 158 |
+
except:
|
| 159 |
+
pass
|
| 160 |
+
|
| 161 |
+
metrics['object_count'] = len(hit_objects)
|
| 162 |
+
|
| 163 |
+
if len(hit_objects) > 1:
|
| 164 |
+
# Calculate spacing metrics
|
| 165 |
+
distances = []
|
| 166 |
+
for i in range(1, len(hit_objects)):
|
| 167 |
+
dist = np.sqrt(
|
| 168 |
+
(hit_objects[i][0] - hit_objects[i-1][0])**2 +
|
| 169 |
+
(hit_objects[i][1] - hit_objects[i-1][1])**2
|
| 170 |
+
)
|
| 171 |
+
distances.append(dist)
|
| 172 |
+
|
| 173 |
+
metrics['avg_spacing'] = np.mean(distances)
|
| 174 |
+
metrics['spacing_variance'] = np.var(distances)
|
| 175 |
+
|
| 176 |
+
# Pattern diversity (entropy of distance distribution)
|
| 177 |
+
hist, _ = np.histogram(distances, bins=10)
|
| 178 |
+
hist = hist / hist.sum()
|
| 179 |
+
entropy = -np.sum(hist * np.log(hist + 1e-10))
|
| 180 |
+
metrics['pattern_diversity'] = entropy
|
| 181 |
+
|
| 182 |
+
# Flow score (based on angle changes)
|
| 183 |
+
if len(hit_objects) > 2:
|
| 184 |
+
angles = []
|
| 185 |
+
for i in range(2, len(hit_objects)):
|
| 186 |
+
angle = self._calculate_angle(
|
| 187 |
+
hit_objects[i-2],
|
| 188 |
+
hit_objects[i-1],
|
| 189 |
+
hit_objects[i]
|
| 190 |
+
)
|
| 191 |
+
angles.append(angle)
|
| 192 |
+
|
| 193 |
+
# Lower angle variance = better flow
|
| 194 |
+
metrics['flow_score'] = 1.0 / (1.0 + np.var(angles) / 100)
|
| 195 |
+
|
| 196 |
+
# Difficulty consistency
|
| 197 |
+
chunk_size = max(10, len(distances) // 10)
|
| 198 |
+
chunk_variances = []
|
| 199 |
+
for i in range(0, len(distances), chunk_size):
|
| 200 |
+
chunk = distances[i:i+chunk_size]
|
| 201 |
+
if chunk:
|
| 202 |
+
chunk_variances.append(np.var(chunk))
|
| 203 |
+
|
| 204 |
+
if chunk_variances:
|
| 205 |
+
metrics['difficulty_consistency'] = 1.0 / (1.0 + np.var(chunk_variances))
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
logger.error(f"Error analyzing quality: {e}")
|
| 209 |
+
|
| 210 |
+
return metrics
|
| 211 |
+
|
| 212 |
+
def _calculate_angle(self, p1: Tuple, p2: Tuple, p3: Tuple) -> float:
|
| 213 |
+
"""Calculate angle between three points"""
|
| 214 |
+
v1 = (p2[0] - p1[0], p2[1] - p1[1])
|
| 215 |
+
v2 = (p3[0] - p2[0], p3[1] - p2[1])
|
| 216 |
+
|
| 217 |
+
angle1 = np.arctan2(v1[1], v1[0])
|
| 218 |
+
angle2 = np.arctan2(v2[1], v2[0])
|
| 219 |
+
|
| 220 |
+
angle_diff = angle2 - angle1
|
| 221 |
+
# Normalize to [-pi, pi]
|
| 222 |
+
while angle_diff > np.pi:
|
| 223 |
+
angle_diff -= 2 * np.pi
|
| 224 |
+
while angle_diff < -np.pi:
|
| 225 |
+
angle_diff += 2 * np.pi
|
| 226 |
+
|
| 227 |
+
return abs(angle_diff)
|
| 228 |
+
|
| 229 |
+
def run_benchmark_suite(self, test_audio_files: List[str]):
|
| 230 |
+
"""Run complete benchmark suite"""
|
| 231 |
+
|
| 232 |
+
models = ['beatheritage_v1', 'v30']
|
| 233 |
+
gamemodes = [0, 1, 2, 3] # All gamemodes
|
| 234 |
+
difficulties = [3.0, 5.5, 7.5] # Easy, Normal, Hard
|
| 235 |
+
|
| 236 |
+
total_tests = len(test_audio_files) * len(models) * len(gamemodes) * len(difficulties)
|
| 237 |
+
|
| 238 |
+
with tqdm(total=total_tests, desc="Running benchmarks") as pbar:
|
| 239 |
+
for audio_file in test_audio_files:
|
| 240 |
+
for gamemode in gamemodes:
|
| 241 |
+
for difficulty in difficulties:
|
| 242 |
+
for model in models:
|
| 243 |
+
logger.info(f"Testing {model} on {audio_file} "
|
| 244 |
+
f"(GM:{gamemode}, Diff:{difficulty})")
|
| 245 |
+
|
| 246 |
+
result = self.run_inference(
|
| 247 |
+
model, audio_file, gamemode, difficulty
|
| 248 |
+
)
|
| 249 |
+
self.results.append(result)
|
| 250 |
+
pbar.update(1)
|
| 251 |
+
|
| 252 |
+
# Save intermediate results
|
| 253 |
+
self._save_results()
|
| 254 |
+
|
| 255 |
+
def _save_results(self):
|
| 256 |
+
"""Save benchmark results to JSON and CSV"""
|
| 257 |
+
# Save as JSON
|
| 258 |
+
json_path = self.output_dir / f"benchmark_results_{self.timestamp}.json"
|
| 259 |
+
with open(json_path, 'w') as f:
|
| 260 |
+
json.dump(self.results, f, indent=2)
|
| 261 |
+
|
| 262 |
+
# Save as CSV for analysis
|
| 263 |
+
df = pd.DataFrame(self.results)
|
| 264 |
+
csv_path = self.output_dir / f"benchmark_results_{self.timestamp}.csv"
|
| 265 |
+
df.to_csv(csv_path, index=False)
|
| 266 |
+
|
| 267 |
+
logger.info(f"Results saved to {json_path} and {csv_path}")
|
| 268 |
+
|
| 269 |
+
def generate_report(self):
|
| 270 |
+
"""Generate comprehensive benchmark report with visualizations"""
|
| 271 |
+
|
| 272 |
+
if not self.results:
|
| 273 |
+
logger.error("No results to generate report")
|
| 274 |
+
return
|
| 275 |
+
|
| 276 |
+
df = pd.DataFrame(self.results)
|
| 277 |
+
|
| 278 |
+
# Create visualizations
|
| 279 |
+
fig = plt.figure(figsize=(20, 12))
|
| 280 |
+
|
| 281 |
+
# 1. Generation Time Comparison
|
| 282 |
+
ax1 = plt.subplot(2, 3, 1)
|
| 283 |
+
successful_df = df[df['success'] == True]
|
| 284 |
+
if not successful_df.empty:
|
| 285 |
+
sns.boxplot(data=successful_df, x='model', y='generation_time', ax=ax1)
|
| 286 |
+
ax1.set_title('Generation Time Comparison')
|
| 287 |
+
ax1.set_ylabel('Time (seconds)')
|
| 288 |
+
|
| 289 |
+
# 2. Memory Usage Comparison
|
| 290 |
+
ax2 = plt.subplot(2, 3, 2)
|
| 291 |
+
if not successful_df.empty:
|
| 292 |
+
sns.boxplot(data=successful_df, x='model', y='memory_usage', ax=ax2)
|
| 293 |
+
ax2.set_title('Memory Usage Comparison')
|
| 294 |
+
ax2.set_ylabel('Memory (MB)')
|
| 295 |
+
|
| 296 |
+
# 3. Success Rate
|
| 297 |
+
ax3 = plt.subplot(2, 3, 3)
|
| 298 |
+
success_rates = df.groupby('model')['success'].mean() * 100
|
| 299 |
+
success_rates.plot(kind='bar', ax=ax3)
|
| 300 |
+
ax3.set_title('Success Rate (%)')
|
| 301 |
+
ax3.set_ylabel('Success Rate')
|
| 302 |
+
ax3.set_ylim(0, 105)
|
| 303 |
+
|
| 304 |
+
# 4. Quality Metrics Comparison
|
| 305 |
+
if not successful_df.empty and 'quality_metrics' in successful_df.columns:
|
| 306 |
+
# Extract quality metrics
|
| 307 |
+
quality_data = []
|
| 308 |
+
for _, row in successful_df.iterrows():
|
| 309 |
+
if row['quality_metrics']:
|
| 310 |
+
quality_data.append({
|
| 311 |
+
'model': row['model'],
|
| 312 |
+
'pattern_diversity': row['quality_metrics'].get('pattern_diversity', 0),
|
| 313 |
+
'flow_score': row['quality_metrics'].get('flow_score', 0),
|
| 314 |
+
'difficulty_consistency': row['quality_metrics'].get('difficulty_consistency', 0)
|
| 315 |
+
})
|
| 316 |
+
|
| 317 |
+
if quality_data:
|
| 318 |
+
quality_df = pd.DataFrame(quality_data)
|
| 319 |
+
|
| 320 |
+
# Pattern Diversity
|
| 321 |
+
ax4 = plt.subplot(2, 3, 4)
|
| 322 |
+
if 'pattern_diversity' in quality_df.columns:
|
| 323 |
+
sns.boxplot(data=quality_df, x='model', y='pattern_diversity', ax=ax4)
|
| 324 |
+
ax4.set_title('Pattern Diversity Score')
|
| 325 |
+
|
| 326 |
+
# Flow Score
|
| 327 |
+
ax5 = plt.subplot(2, 3, 5)
|
| 328 |
+
if 'flow_score' in quality_df.columns:
|
| 329 |
+
sns.boxplot(data=quality_df, x='model', y='flow_score', ax=ax5)
|
| 330 |
+
ax5.set_title('Flow Quality Score')
|
| 331 |
+
|
| 332 |
+
# Difficulty Consistency
|
| 333 |
+
ax6 = plt.subplot(2, 3, 6)
|
| 334 |
+
if 'difficulty_consistency' in quality_df.columns:
|
| 335 |
+
sns.boxplot(data=quality_df, x='model', y='difficulty_consistency', ax=ax6)
|
| 336 |
+
ax6.set_title('Difficulty Consistency Score')
|
| 337 |
+
|
| 338 |
+
plt.suptitle('BeatHeritage V1 vs Mapperatorinator V30 Benchmark Report', fontsize=16)
|
| 339 |
+
plt.tight_layout()
|
| 340 |
+
|
| 341 |
+
# Save plot
|
| 342 |
+
plot_path = self.output_dir / f"benchmark_report_{self.timestamp}.png"
|
| 343 |
+
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
|
| 344 |
+
plt.show()
|
| 345 |
+
|
| 346 |
+
# Generate text summary
|
| 347 |
+
summary = self._generate_text_summary(df)
|
| 348 |
+
summary_path = self.output_dir / f"benchmark_summary_{self.timestamp}.txt"
|
| 349 |
+
with open(summary_path, 'w') as f:
|
| 350 |
+
f.write(summary)
|
| 351 |
+
|
| 352 |
+
logger.info(f"Report generated: {plot_path} and {summary_path}")
|
| 353 |
+
|
| 354 |
+
def _generate_text_summary(self, df: pd.DataFrame) -> str:
|
| 355 |
+
"""Generate text summary of benchmark results"""
|
| 356 |
+
|
| 357 |
+
summary = []
|
| 358 |
+
summary.append("=" * 80)
|
| 359 |
+
summary.append("BEATHERITAGE V1 VS MAPPERATORINATOR V30 BENCHMARK SUMMARY")
|
| 360 |
+
summary.append("=" * 80)
|
| 361 |
+
summary.append(f"Timestamp: {self.timestamp}")
|
| 362 |
+
summary.append(f"Total Tests: {len(df)}")
|
| 363 |
+
summary.append("")
|
| 364 |
+
|
| 365 |
+
for model in df['model'].unique():
|
| 366 |
+
model_df = df[df['model'] == model]
|
| 367 |
+
successful_df = model_df[model_df['success'] == True]
|
| 368 |
+
|
| 369 |
+
summary.append(f"\n{model.upper()}")
|
| 370 |
+
summary.append("-" * 40)
|
| 371 |
+
summary.append(f"Success Rate: {model_df['success'].mean()*100:.1f}%")
|
| 372 |
+
|
| 373 |
+
if not successful_df.empty:
|
| 374 |
+
summary.append(f"Avg Generation Time: {successful_df['generation_time'].mean():.2f}s")
|
| 375 |
+
summary.append(f"Avg Memory Usage: {successful_df['memory_usage'].mean():.1f}MB")
|
| 376 |
+
|
| 377 |
+
# Quality metrics
|
| 378 |
+
quality_metrics = []
|
| 379 |
+
for _, row in successful_df.iterrows():
|
| 380 |
+
if row['quality_metrics']:
|
| 381 |
+
quality_metrics.append(row['quality_metrics'])
|
| 382 |
+
|
| 383 |
+
if quality_metrics:
|
| 384 |
+
avg_diversity = np.mean([m.get('pattern_diversity', 0) for m in quality_metrics])
|
| 385 |
+
avg_flow = np.mean([m.get('flow_score', 0) for m in quality_metrics])
|
| 386 |
+
avg_consistency = np.mean([m.get('difficulty_consistency', 0) for m in quality_metrics])
|
| 387 |
+
|
| 388 |
+
summary.append(f"Avg Pattern Diversity: {avg_diversity:.3f}")
|
| 389 |
+
summary.append(f"Avg Flow Score: {avg_flow:.3f}")
|
| 390 |
+
summary.append(f"Avg Difficulty Consistency: {avg_consistency:.3f}")
|
| 391 |
+
|
| 392 |
+
# Winner determination
|
| 393 |
+
summary.append("\n" + "=" * 80)
|
| 394 |
+
summary.append("WINNER ANALYSIS")
|
| 395 |
+
summary.append("=" * 80)
|
| 396 |
+
|
| 397 |
+
if len(df['model'].unique()) == 2:
|
| 398 |
+
model1, model2 = df['model'].unique()
|
| 399 |
+
|
| 400 |
+
# Compare metrics
|
| 401 |
+
metrics_comparison = []
|
| 402 |
+
|
| 403 |
+
for metric in ['generation_time', 'memory_usage']:
|
| 404 |
+
m1_avg = df[df['model'] == model1][metric].mean()
|
| 405 |
+
m2_avg = df[df['model'] == model2][metric].mean()
|
| 406 |
+
|
| 407 |
+
if m1_avg < m2_avg:
|
| 408 |
+
winner = model1
|
| 409 |
+
improvement = ((m2_avg - m1_avg) / m2_avg) * 100
|
| 410 |
+
else:
|
| 411 |
+
winner = model2
|
| 412 |
+
improvement = ((m1_avg - m2_avg) / m1_avg) * 100
|
| 413 |
+
|
| 414 |
+
metrics_comparison.append(
|
| 415 |
+
f"{metric}: {winner} ({improvement:.1f}% better)"
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
for comp in metrics_comparison:
|
| 419 |
+
summary.append(comp)
|
| 420 |
+
|
| 421 |
+
return "\n".join(summary)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def main():
|
| 425 |
+
parser = argparse.ArgumentParser(description='Benchmark BeatHeritage V1 vs V30')
|
| 426 |
+
parser.add_argument(
|
| 427 |
+
'--audio-dir',
|
| 428 |
+
type=str,
|
| 429 |
+
default='./test_audio',
|
| 430 |
+
help='Directory containing test audio files'
|
| 431 |
+
)
|
| 432 |
+
parser.add_argument(
|
| 433 |
+
'--output-dir',
|
| 434 |
+
type=str,
|
| 435 |
+
default='./benchmark_results',
|
| 436 |
+
help='Directory to save benchmark results'
|
| 437 |
+
)
|
| 438 |
+
parser.add_argument(
|
| 439 |
+
'--quick-test',
|
| 440 |
+
action='store_true',
|
| 441 |
+
help='Run quick test with limited parameters'
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
args = parser.parse_args()
|
| 445 |
+
|
| 446 |
+
# Get test audio files
|
| 447 |
+
audio_dir = Path(args.audio_dir)
|
| 448 |
+
if audio_dir.exists():
|
| 449 |
+
audio_files = list(audio_dir.glob('*.mp3')) + list(audio_dir.glob('*.ogg'))
|
| 450 |
+
else:
|
| 451 |
+
# Use demo files
|
| 452 |
+
logger.warning(f"Audio directory {audio_dir} not found, using demo files")
|
| 453 |
+
audio_files = ['demo.mp3'] # Fallback to demo
|
| 454 |
+
|
| 455 |
+
if args.quick_test:
|
| 456 |
+
# Quick test with limited parameters
|
| 457 |
+
audio_files = audio_files[:1]
|
| 458 |
+
logger.info("Running quick test with 1 audio file")
|
| 459 |
+
|
| 460 |
+
# Run benchmarks
|
| 461 |
+
runner = BenchmarkRunner(args.output_dir)
|
| 462 |
+
runner.run_benchmark_suite([str(f) for f in audio_files])
|
| 463 |
+
runner.generate_report()
|
| 464 |
+
|
| 465 |
+
logger.info("Benchmark complete!")
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
if __name__ == "__main__":
|
| 469 |
+
main()
|
calc_fid.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import traceback
|
| 5 |
+
from datetime import timedelta
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import hydra
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from scipy import linalg
|
| 13 |
+
from slider import Beatmap, Circle, Slider, Spinner, HoldNote
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from classifier.classify import ExampleDataset
|
| 18 |
+
from classifier.libs.model.model import OsuClassifierOutput
|
| 19 |
+
from classifier.libs.utils import load_ckpt
|
| 20 |
+
from config import FidConfig
|
| 21 |
+
from inference import prepare_args, load_diff_model, generate, load_model
|
| 22 |
+
from osuT5.osuT5.dataset.data_utils import load_audio_file, load_mmrs_metadata, filter_mmrs_metadata
|
| 23 |
+
from osuT5.osuT5.inference import generation_config_from_beatmap, beatmap_config_from_beatmap
|
| 24 |
+
from osuT5.osuT5.tokenizer import ContextType
|
| 25 |
+
from multiprocessing import Manager, Process
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_beatmap_paths(args: FidConfig) -> list[Path]:
|
| 31 |
+
"""Get all beatmap paths (.osu) from the dataset directory."""
|
| 32 |
+
dataset_path = Path(args.dataset_path)
|
| 33 |
+
|
| 34 |
+
if args.dataset_type == "mmrs":
|
| 35 |
+
metadata = load_mmrs_metadata(dataset_path)
|
| 36 |
+
filtered_metadata = filter_mmrs_metadata(
|
| 37 |
+
metadata,
|
| 38 |
+
start=args.dataset_start,
|
| 39 |
+
end=args.dataset_end,
|
| 40 |
+
gamemodes=args.gamemodes,
|
| 41 |
+
)
|
| 42 |
+
beatmap_files = [dataset_path / "data" / item["BeatmapSetFolder"] / item["BeatmapFile"] for _, item in filtered_metadata.iterrows()]
|
| 43 |
+
elif args.dataset_type == "ors":
|
| 44 |
+
beatmap_files = []
|
| 45 |
+
track_names = ["Track" + str(i).zfill(5) for i in range(args.dataset_start, args.dataset_end)]
|
| 46 |
+
for track_name in track_names:
|
| 47 |
+
for beatmap_file in (dataset_path / track_name / "beatmaps").iterdir():
|
| 48 |
+
beatmap_files.append(dataset_path / track_name / "beatmaps" / beatmap_file.name)
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError(f"Unknown dataset type: {args.dataset_type}")
|
| 51 |
+
|
| 52 |
+
return beatmap_files
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
| 56 |
+
"""Numpy implementation of the Frechet Distance.
|
| 57 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
| 58 |
+
and X_2 ~ N(mu_2, C_2) is
|
| 59 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
| 60 |
+
|
| 61 |
+
Stable version by Dougal J. Sutherland.
|
| 62 |
+
|
| 63 |
+
Params:
|
| 64 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
| 65 |
+
inception net (like returned by the function 'get_predictions')
|
| 66 |
+
for generated samples.
|
| 67 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
| 68 |
+
representative data set.
|
| 69 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
| 70 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
| 71 |
+
representative data set.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
-- : The Frechet Distance.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
mu1 = np.atleast_1d(mu1)
|
| 78 |
+
mu2 = np.atleast_1d(mu2)
|
| 79 |
+
|
| 80 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 81 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 82 |
+
|
| 83 |
+
assert (
|
| 84 |
+
mu1.shape == mu2.shape
|
| 85 |
+
), "Training and test mean vectors have different lengths"
|
| 86 |
+
assert (
|
| 87 |
+
sigma1.shape == sigma2.shape
|
| 88 |
+
), "Training and test covariances have different dimensions"
|
| 89 |
+
|
| 90 |
+
diff = mu1 - mu2
|
| 91 |
+
|
| 92 |
+
# Product might be almost singular
|
| 93 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 94 |
+
if not np.isfinite(covmean).all():
|
| 95 |
+
msg = (
|
| 96 |
+
"fid calculation produces singular product; "
|
| 97 |
+
"adding %s to diagonal of cov estimates"
|
| 98 |
+
) % eps
|
| 99 |
+
logger.warning(msg)
|
| 100 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 101 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 102 |
+
|
| 103 |
+
# Numerical error might give slight imaginary component
|
| 104 |
+
if np.iscomplexobj(covmean):
|
| 105 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 106 |
+
m = np.max(np.abs(covmean.imag))
|
| 107 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 108 |
+
covmean = covmean.real
|
| 109 |
+
|
| 110 |
+
tr_covmean = np.trace(covmean)
|
| 111 |
+
|
| 112 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def add_to_dict(source_dict, target_dict):
|
| 116 |
+
for key, value in source_dict.items():
|
| 117 |
+
if key not in target_dict:
|
| 118 |
+
target_dict[key] = value
|
| 119 |
+
else:
|
| 120 |
+
target_dict[key] += value
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def calculate_rhythm_stats(real_rhythm, generated_rhythm):
|
| 124 |
+
# Rhythm is a set of timestamps for each beat
|
| 125 |
+
# Calculate number of true positives, false positives, and false negatives within a leniency of 10 ms
|
| 126 |
+
leniency = 10
|
| 127 |
+
true_positives = 0
|
| 128 |
+
false_positives = 0
|
| 129 |
+
false_negatives = 0
|
| 130 |
+
for real_beat in real_rhythm:
|
| 131 |
+
if any(abs(real_beat - gen_beat) <= leniency for gen_beat in generated_rhythm):
|
| 132 |
+
true_positives += 1
|
| 133 |
+
else:
|
| 134 |
+
false_negatives += 1
|
| 135 |
+
|
| 136 |
+
for gen_beat in generated_rhythm:
|
| 137 |
+
if not any(abs(gen_beat - real_beat) <= leniency for real_beat in real_rhythm):
|
| 138 |
+
false_positives += 1
|
| 139 |
+
|
| 140 |
+
return {
|
| 141 |
+
"true_positives": true_positives,
|
| 142 |
+
"false_positives": false_positives,
|
| 143 |
+
"false_negatives": false_negatives,
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def calculate_precision(rhythm_stats):
|
| 148 |
+
true_positives = rhythm_stats["true_positives"]
|
| 149 |
+
false_positives = rhythm_stats["false_positives"]
|
| 150 |
+
if true_positives + false_positives == 0:
|
| 151 |
+
return 0.0
|
| 152 |
+
return true_positives / (true_positives + false_positives)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def calculate_recall(rhythm_stats):
|
| 156 |
+
true_positives = rhythm_stats["true_positives"]
|
| 157 |
+
false_negatives = rhythm_stats["false_negatives"]
|
| 158 |
+
if true_positives + false_negatives == 0:
|
| 159 |
+
return 0.0
|
| 160 |
+
return true_positives / (true_positives + false_negatives)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def calculate_f1(rhythm_stats):
|
| 164 |
+
precision = calculate_precision(rhythm_stats)
|
| 165 |
+
recall = calculate_recall(rhythm_stats)
|
| 166 |
+
if precision + recall == 0:
|
| 167 |
+
return 0.0
|
| 168 |
+
return 2 * (precision * recall) / (precision + recall)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_rhythm(beatmap, passive=False):
|
| 172 |
+
# Extract the rhythm from the beatmap
|
| 173 |
+
# Active rhythm includes only circles, slider heads, and hold note heads
|
| 174 |
+
# Passive rhythm also includes slider tails, slider repeats, and spinners tails
|
| 175 |
+
rhythm = set()
|
| 176 |
+
for hit_object in beatmap.hit_objects(stacking=False):
|
| 177 |
+
if isinstance(hit_object, Circle):
|
| 178 |
+
rhythm.add(int(hit_object.time.total_seconds() * 1000 + 1e-5))
|
| 179 |
+
elif isinstance(hit_object, Slider):
|
| 180 |
+
duration: timedelta = (hit_object.end_time - hit_object.time) / hit_object.repeat
|
| 181 |
+
rhythm.add(int(hit_object.time.total_seconds() * 1000 + 1e-5))
|
| 182 |
+
if passive:
|
| 183 |
+
for i in range(hit_object.repeat):
|
| 184 |
+
rhythm.add(int((hit_object.time + duration * (i + 1)).total_seconds() * 1000 + 1e-5))
|
| 185 |
+
elif isinstance(hit_object, Spinner):
|
| 186 |
+
if passive:
|
| 187 |
+
rhythm.add(int(hit_object.end_time.total_seconds() * 1000 + 1e-5))
|
| 188 |
+
elif isinstance(hit_object, HoldNote):
|
| 189 |
+
rhythm.add(int(hit_object.time.total_seconds() * 1000 + 1e-5))
|
| 190 |
+
|
| 191 |
+
return rhythm
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def generate_beatmaps(beatmap_paths, fid_args: FidConfig, return_dict, idx):
|
| 195 |
+
args = fid_args.inference
|
| 196 |
+
args.device = fid_args.device
|
| 197 |
+
torch.set_grad_enabled(False)
|
| 198 |
+
torch.set_float32_matmul_precision('high')
|
| 199 |
+
|
| 200 |
+
model, tokenizer, diff_model, diff_tokenizer, refine_model = None, None, None, None, None
|
| 201 |
+
model, tokenizer = load_model(args.model_path, args.train, args.device, args.max_batch_size, args.use_server, args.precision)
|
| 202 |
+
|
| 203 |
+
if args.compile:
|
| 204 |
+
model.transformer.forward = torch.compile(model.transformer.forward, mode="reduce-overhead", fullgraph=True)
|
| 205 |
+
|
| 206 |
+
if args.generate_positions:
|
| 207 |
+
diff_model, diff_tokenizer = load_diff_model(args.diff_ckpt, args.diffusion, args.device)
|
| 208 |
+
|
| 209 |
+
if os.path.exists(args.diff_refine_ckpt):
|
| 210 |
+
refine_model = load_diff_model(args.diff_refine_ckpt, args.diffusion, args.device)[0]
|
| 211 |
+
|
| 212 |
+
if args.compile:
|
| 213 |
+
diff_model.forward = torch.compile(diff_model.forward, mode="reduce-overhead", fullgraph=False)
|
| 214 |
+
|
| 215 |
+
for beatmap_path in tqdm(beatmap_paths, desc=f"Process {idx}"):
|
| 216 |
+
try:
|
| 217 |
+
beatmap = Beatmap.from_path(beatmap_path)
|
| 218 |
+
output_path = Path("generated") / beatmap_path.stem
|
| 219 |
+
|
| 220 |
+
if fid_args.dataset_type == "ors":
|
| 221 |
+
audio_path = beatmap_path.parents[1] / list(beatmap_path.parents[1].glob('audio.*'))[0]
|
| 222 |
+
else:
|
| 223 |
+
audio_path = beatmap_path.parent / beatmap.audio_filename
|
| 224 |
+
|
| 225 |
+
if fid_args.skip_generation or (output_path.exists() and len(list(output_path.glob("*.osu"))) > 0):
|
| 226 |
+
if not output_path.exists() or len(list(output_path.glob("*.osu"))) == 0:
|
| 227 |
+
raise FileNotFoundError(f"Generated beatmap not found in {output_path}")
|
| 228 |
+
print(f"Skipping {beatmap_path.stem} as it already exists")
|
| 229 |
+
else:
|
| 230 |
+
if ContextType.GD in args.in_context:
|
| 231 |
+
other_beatmaps = [k for k in beatmap_path.parent.glob("*.osu") if k != beatmap_path]
|
| 232 |
+
if len(other_beatmaps) == 0:
|
| 233 |
+
continue
|
| 234 |
+
other_beatmap_path = random.choice(other_beatmaps)
|
| 235 |
+
else:
|
| 236 |
+
other_beatmap_path = beatmap_path
|
| 237 |
+
|
| 238 |
+
generation_config = generation_config_from_beatmap(beatmap, tokenizer)
|
| 239 |
+
beatmap_config = beatmap_config_from_beatmap(beatmap)
|
| 240 |
+
beatmap_config.version = args.version
|
| 241 |
+
|
| 242 |
+
if args.year is not None:
|
| 243 |
+
generation_config.year = args.year
|
| 244 |
+
|
| 245 |
+
result = generate(
|
| 246 |
+
args,
|
| 247 |
+
audio_path=audio_path,
|
| 248 |
+
beatmap_path=other_beatmap_path,
|
| 249 |
+
output_path=output_path,
|
| 250 |
+
generation_config=generation_config,
|
| 251 |
+
beatmap_config=beatmap_config,
|
| 252 |
+
model=model,
|
| 253 |
+
tokenizer=tokenizer,
|
| 254 |
+
diff_model=diff_model,
|
| 255 |
+
diff_tokenizer=diff_tokenizer,
|
| 256 |
+
refine_model=refine_model,
|
| 257 |
+
verbose=False,
|
| 258 |
+
)[0]
|
| 259 |
+
generated_beatmap = Beatmap.parse(result)
|
| 260 |
+
print(beatmap_path, "Generated %s hit objects" % len(generated_beatmap.hit_objects(stacking=False)))
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"Error processing {beatmap_path}: {e}")
|
| 263 |
+
traceback.print_exc()
|
| 264 |
+
finally:
|
| 265 |
+
torch.cuda.empty_cache() # Clear any cached memory
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def calculate_metrics(args: FidConfig, beatmap_paths: list[Path]):
|
| 269 |
+
print("Calculating metrics...")
|
| 270 |
+
|
| 271 |
+
classifier_model, classifier_args, classifier_tokenizer = None, None, None
|
| 272 |
+
if args.fid:
|
| 273 |
+
classifier_model, classifier_args, classifier_tokenizer = load_ckpt(args.classifier_ckpt)
|
| 274 |
+
|
| 275 |
+
if args.compile:
|
| 276 |
+
classifier_model.model.transformer.forward = torch.compile(classifier_model.model.transformer.forward,
|
| 277 |
+
mode="reduce-overhead", fullgraph=False)
|
| 278 |
+
|
| 279 |
+
real_features = []
|
| 280 |
+
generated_features = []
|
| 281 |
+
active_rhythm_stats = {}
|
| 282 |
+
passive_rhythm_stats = {}
|
| 283 |
+
|
| 284 |
+
for beatmap_path in tqdm(beatmap_paths, desc=f"Metrics"):
|
| 285 |
+
try:
|
| 286 |
+
beatmap = Beatmap.from_path(beatmap_path)
|
| 287 |
+
generated_path = Path("generated") / beatmap_path.stem
|
| 288 |
+
|
| 289 |
+
if args.dataset_type == "ors":
|
| 290 |
+
audio_path = beatmap_path.parents[1] / list(beatmap_path.parents[1].glob('audio.*'))[0]
|
| 291 |
+
else:
|
| 292 |
+
audio_path = beatmap_path.parent / beatmap.audio_filename
|
| 293 |
+
|
| 294 |
+
if generated_path.exists() and len(list(generated_path.glob("*.osu"))) > 0:
|
| 295 |
+
generated_beatmap = Beatmap.from_path(list(generated_path.glob("*.osu"))[0])
|
| 296 |
+
else:
|
| 297 |
+
logger.warning(f"Skipping {beatmap_path.stem} as no generated beatmap found")
|
| 298 |
+
continue
|
| 299 |
+
|
| 300 |
+
if args.fid:
|
| 301 |
+
# Calculate feature vectors for real and generated beatmaps
|
| 302 |
+
sample_rate = classifier_args.data.sample_rate
|
| 303 |
+
audio = load_audio_file(audio_path, sample_rate, normalize=args.inference.train.data.normalize_audio)
|
| 304 |
+
|
| 305 |
+
for example in DataLoader(
|
| 306 |
+
ExampleDataset(beatmap, audio, classifier_args, classifier_tokenizer, args.device),
|
| 307 |
+
batch_size=args.classifier_batch_size):
|
| 308 |
+
classifier_result: OsuClassifierOutput = classifier_model(**example)
|
| 309 |
+
features = classifier_result.feature_vector
|
| 310 |
+
real_features.append(features.cpu().numpy())
|
| 311 |
+
|
| 312 |
+
for example in DataLoader(
|
| 313 |
+
ExampleDataset(generated_beatmap, audio, classifier_args, classifier_tokenizer, args.device),
|
| 314 |
+
batch_size=args.classifier_batch_size):
|
| 315 |
+
classifier_result: OsuClassifierOutput = classifier_model(**example)
|
| 316 |
+
features = classifier_result.feature_vector
|
| 317 |
+
generated_features.append(features.cpu().numpy())
|
| 318 |
+
|
| 319 |
+
if args.rhythm_stats:
|
| 320 |
+
# Calculate rhythm stats
|
| 321 |
+
real_active_rhythm = get_rhythm(beatmap, passive=False)
|
| 322 |
+
generated_active_rhythm = get_rhythm(generated_beatmap, passive=False)
|
| 323 |
+
add_to_dict(calculate_rhythm_stats(real_active_rhythm, generated_active_rhythm), active_rhythm_stats)
|
| 324 |
+
|
| 325 |
+
real_passive_rhythm = get_rhythm(beatmap, passive=True)
|
| 326 |
+
generated_passive_rhythm = get_rhythm(generated_beatmap, passive=True)
|
| 327 |
+
add_to_dict(calculate_rhythm_stats(real_passive_rhythm, generated_passive_rhythm), passive_rhythm_stats)
|
| 328 |
+
except Exception as e:
|
| 329 |
+
print(f"Error processing {beatmap_path}: {e}")
|
| 330 |
+
traceback.print_exc()
|
| 331 |
+
finally:
|
| 332 |
+
torch.cuda.empty_cache() # Clear any cached memory
|
| 333 |
+
|
| 334 |
+
if args.fid:
|
| 335 |
+
# Calculate FID
|
| 336 |
+
real_features = np.concatenate(real_features, axis=0)
|
| 337 |
+
generated_features = np.concatenate(generated_features, axis=0)
|
| 338 |
+
m1, s1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
|
| 339 |
+
m2, s2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
|
| 340 |
+
fid = calculate_frechet_distance(m1, s1, m2, s2)
|
| 341 |
+
|
| 342 |
+
logger.info(f"FID: {fid}")
|
| 343 |
+
|
| 344 |
+
if args.rhythm_stats:
|
| 345 |
+
# Calculate rhythm precision, recall, and F1 score
|
| 346 |
+
active_precision = calculate_precision(active_rhythm_stats)
|
| 347 |
+
active_recall = calculate_recall(active_rhythm_stats)
|
| 348 |
+
active_f1 = calculate_f1(active_rhythm_stats)
|
| 349 |
+
passive_precision = calculate_precision(passive_rhythm_stats)
|
| 350 |
+
passive_recall = calculate_recall(passive_rhythm_stats)
|
| 351 |
+
passive_f1 = calculate_f1(passive_rhythm_stats)
|
| 352 |
+
logger.info(f"Active Rhythm Precision: {active_precision}")
|
| 353 |
+
logger.info(f"Active Rhythm Recall: {active_recall}")
|
| 354 |
+
logger.info(f"Active Rhythm F1: {active_f1}")
|
| 355 |
+
logger.info(f"Passive Rhythm Precision: {passive_precision}")
|
| 356 |
+
logger.info(f"Passive Rhythm Recall: {passive_recall}")
|
| 357 |
+
logger.info(f"Passive Rhythm F1: {passive_f1}")
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def test_training_set_overlap(beatmap_paths: list[Path], training_set_ids_path: Optional[str]):
|
| 361 |
+
if training_set_ids_path is None:
|
| 362 |
+
return
|
| 363 |
+
|
| 364 |
+
if not os.path.exists(training_set_ids_path):
|
| 365 |
+
logger.error(f"Training set IDs file {training_set_ids_path} does not exist.")
|
| 366 |
+
return
|
| 367 |
+
|
| 368 |
+
with open(training_set_ids_path, "r") as f:
|
| 369 |
+
training_set_ids = set(int(line.strip()) for line in f)
|
| 370 |
+
|
| 371 |
+
in_set = 0
|
| 372 |
+
out_set = 0
|
| 373 |
+
for path in tqdm(beatmap_paths):
|
| 374 |
+
beatmap = Beatmap.from_path(path)
|
| 375 |
+
if beatmap.beatmap_id in training_set_ids:
|
| 376 |
+
in_set += 1
|
| 377 |
+
else:
|
| 378 |
+
out_set += 1
|
| 379 |
+
logger.info(f"In training set: {in_set}, Not in training set: {out_set}, Total: {len(beatmap_paths)}, Ratio: {in_set / (in_set + out_set):.2f}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
@hydra.main(config_path="configs", config_name="calc_fid", version_base="1.1")
|
| 383 |
+
def main(args: FidConfig):
|
| 384 |
+
prepare_args(args)
|
| 385 |
+
|
| 386 |
+
# Fix inference model path
|
| 387 |
+
if args.inference.model_path.startswith("./"):
|
| 388 |
+
args.inference.model_path = os.path.join(Path(__file__).parent, args.inference.model_path[2:])
|
| 389 |
+
|
| 390 |
+
beatmap_paths = get_beatmap_paths(args)
|
| 391 |
+
|
| 392 |
+
test_training_set_overlap(beatmap_paths, args.training_set_ids_path)
|
| 393 |
+
|
| 394 |
+
if not args.skip_generation:
|
| 395 |
+
# Assign beatmaps to processes in a round-robin fashion
|
| 396 |
+
num_processes = args.num_processes
|
| 397 |
+
chunks = [[] for _ in range(num_processes)]
|
| 398 |
+
for i, path in enumerate(beatmap_paths):
|
| 399 |
+
chunks[i % num_processes].append(path)
|
| 400 |
+
|
| 401 |
+
manager = Manager()
|
| 402 |
+
return_dict = manager.dict()
|
| 403 |
+
processes = []
|
| 404 |
+
|
| 405 |
+
for i in range(num_processes):
|
| 406 |
+
p = Process(target=generate_beatmaps, args=(chunks[i], args, return_dict, i))
|
| 407 |
+
processes.append(p)
|
| 408 |
+
p.start()
|
| 409 |
+
|
| 410 |
+
for p in processes:
|
| 411 |
+
p.join()
|
| 412 |
+
|
| 413 |
+
calculate_metrics(args, beatmap_paths)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
if __name__ == "__main__":
|
| 417 |
+
main()
|
classifier/README.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mapper Classifier
|
| 2 |
+
|
| 3 |
+
Try the model [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/classifier_classify.ipynb).
|
| 4 |
+
|
| 5 |
+
Mapper Classifier is a model that predicts which osu! standard ranked mapper mapped a given beatmap.
|
| 6 |
+
|
| 7 |
+
This model is built using transfer learning on the Mapperatorinator V22 model.
|
| 8 |
+
It achieves a top-1 validation accuracy of 12.5% on a random sample of ranked beatmaps and recognizes 3,731 unique mappers.
|
| 9 |
+
To make its predictions, the model analyzes an 8-second segment of beatmap.
|
| 10 |
+
|
| 11 |
+
The purpose of this classifier is actually to calculate high-level feature vectors for beatmaps, which can be used to calculate the similarity between generated beatmaps and real beatmaps.
|
| 12 |
+
This is a technique often used to assess the quality of image generation models with the [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500).
|
| 13 |
+
However, in my testing I found that the computed FID scores for beatmap generation models were not very close to the actual quality of the generated beatmaps.
|
| 14 |
+
This classifier might not be able to recognize all the necessary features to accurately assess the quality of a beatmap, but it's a start.
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
Run `classify.py` with the path to the beatmap you want to classify and the time in seconds of the segment you want to use to classify the beatmap.
|
| 19 |
+
```shell
|
| 20 |
+
python classify.py beatmap_path="'...\Songs\1790119 THE ORAL CIGARETTES - ReI\THE ORAL CIGARETTES - ReI (Sotarks) [Cataclysm.].osu'" time=60
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
Mapper: Sotarks (4452992) with confidence: 9.760356903076172
|
| 25 |
+
Mapper: Sajinn (13513687) with confidence: 6.975161075592041
|
| 26 |
+
Mapper: kowari (5404892) with confidence: 6.800069332122803
|
| 27 |
+
Mapper: Haruto (3772301) with confidence: 6.077754020690918
|
| 28 |
+
Mapper: Kalibe (3376777) with confidence: 5.894346237182617
|
| 29 |
+
Mapper: iljaaz (8501291) with confidence: 5.873990535736084
|
| 30 |
+
Mapper: tomadoi (5712451) with confidence: 5.817874431610107
|
| 31 |
+
Mapper: Nao Tomori (5364763) with confidence: 5.144880294799805
|
| 32 |
+
Mapper: Kujinn (3723568) with confidence: 5.082106590270996
|
| 33 |
+
...
|
| 34 |
+
```
|
classifier/classify.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import numpy.typing as npt
|
| 4 |
+
|
| 5 |
+
import hydra
|
| 6 |
+
import torch
|
| 7 |
+
from omegaconf import DictConfig
|
| 8 |
+
from slider import Beatmap
|
| 9 |
+
from torch.utils.data import IterableDataset
|
| 10 |
+
|
| 11 |
+
from classifier.libs.dataset import OsuParser
|
| 12 |
+
from classifier.libs.dataset.data_utils import load_audio_file
|
| 13 |
+
from classifier.libs.dataset.ors_dataset import STEPS_PER_MILLISECOND
|
| 14 |
+
from classifier.libs.model.model import OsuClassifierOutput
|
| 15 |
+
from classifier.libs.tokenizer import Tokenizer, Event, EventType
|
| 16 |
+
from classifier.libs.utils import load_ckpt
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def iterate_examples(
|
| 20 |
+
beatmap: Beatmap,
|
| 21 |
+
audio: npt.NDArray,
|
| 22 |
+
model_args: DictConfig,
|
| 23 |
+
tokenizer: Tokenizer,
|
| 24 |
+
device: torch.device
|
| 25 |
+
):
|
| 26 |
+
frame_seq_len = model_args.data.src_seq_len - 1
|
| 27 |
+
frame_size = model_args.data.hop_length
|
| 28 |
+
sample_rate = model_args.data.sample_rate
|
| 29 |
+
samples_per_sequence = frame_seq_len * frame_size
|
| 30 |
+
|
| 31 |
+
parser = OsuParser(model_args, tokenizer)
|
| 32 |
+
events, event_times = parser.parse(beatmap)
|
| 33 |
+
|
| 34 |
+
for sample in range(0, len(audio) - samples_per_sequence, samples_per_sequence):
|
| 35 |
+
example = create_example(events, event_times, audio, sample / sample_rate, model_args, tokenizer, device)
|
| 36 |
+
yield example
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ExampleDataset(IterableDataset):
|
| 40 |
+
def __init__(self, beatmap, audio, classifier_args, classifier_tokenizer, device):
|
| 41 |
+
self.beatmap = beatmap
|
| 42 |
+
self.audio = audio
|
| 43 |
+
self.classifier_args = classifier_args
|
| 44 |
+
self.classifier_tokenizer = classifier_tokenizer
|
| 45 |
+
self.device = device
|
| 46 |
+
|
| 47 |
+
def __iter__(self):
|
| 48 |
+
return iterate_examples(
|
| 49 |
+
self.beatmap,
|
| 50 |
+
self.audio,
|
| 51 |
+
self.classifier_args,
|
| 52 |
+
self.classifier_tokenizer,
|
| 53 |
+
self.device
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def create_example(
|
| 58 |
+
events: list[Event],
|
| 59 |
+
event_times: list[float],
|
| 60 |
+
audio: npt.NDArray,
|
| 61 |
+
time: float,
|
| 62 |
+
model_args: DictConfig,
|
| 63 |
+
tokenizer: Tokenizer,
|
| 64 |
+
device: torch.device,
|
| 65 |
+
unsqueeze: bool = False,
|
| 66 |
+
):
|
| 67 |
+
frame_seq_len = model_args.data.src_seq_len - 1
|
| 68 |
+
frame_size = model_args.data.hop_length
|
| 69 |
+
sample_rate = model_args.data.sample_rate
|
| 70 |
+
samples_per_sequence = frame_seq_len * frame_size
|
| 71 |
+
sequence_duration = samples_per_sequence / sample_rate
|
| 72 |
+
|
| 73 |
+
# Get audio frames
|
| 74 |
+
frame_start = int(time * sample_rate)
|
| 75 |
+
frames = audio[frame_start:frame_start + samples_per_sequence]
|
| 76 |
+
frames = torch.from_numpy(frames).to(torch.float32).to(device)
|
| 77 |
+
|
| 78 |
+
# Get the events between time and time + sequence_duration
|
| 79 |
+
events = [event for event, event_time in zip(events, event_times) if
|
| 80 |
+
time <= event_time / 1000 < time + sequence_duration]
|
| 81 |
+
# Normalize time shifts
|
| 82 |
+
for i, event in enumerate(events):
|
| 83 |
+
if event.type == EventType.TIME_SHIFT:
|
| 84 |
+
events[i] = Event(EventType.TIME_SHIFT, int((event.value - time * 1000) * STEPS_PER_MILLISECOND))
|
| 85 |
+
|
| 86 |
+
# Tokenize the events
|
| 87 |
+
tokens = torch.full((model_args.data.tgt_seq_len,), tokenizer.pad_id, dtype=torch.long)
|
| 88 |
+
for i in range(min(len(events), model_args.data.tgt_seq_len)):
|
| 89 |
+
tokens[i] = tokenizer.encode(events[i])
|
| 90 |
+
tokens = tokens.to(device)
|
| 91 |
+
|
| 92 |
+
if unsqueeze:
|
| 93 |
+
tokens = tokens.unsqueeze(0)
|
| 94 |
+
frames = frames.unsqueeze(0)
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"decoder_input_ids": tokens,
|
| 98 |
+
"decoder_attention_mask": tokens != tokenizer.pad_id,
|
| 99 |
+
"frames": frames,
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_example_from_path(
|
| 104 |
+
beatmap_path: str,
|
| 105 |
+
audio_path: str,
|
| 106 |
+
time: float,
|
| 107 |
+
model_args: DictConfig,
|
| 108 |
+
tokenizer: Tokenizer,
|
| 109 |
+
device: torch.device,
|
| 110 |
+
unsqueeze: bool = False,
|
| 111 |
+
):
|
| 112 |
+
sample_rate = model_args.data.sample_rate
|
| 113 |
+
|
| 114 |
+
beatmap_path = Path(beatmap_path)
|
| 115 |
+
beatmap = Beatmap.from_path(beatmap_path)
|
| 116 |
+
|
| 117 |
+
# Get audio frames
|
| 118 |
+
if audio_path == '':
|
| 119 |
+
audio_path = beatmap_path.parent / beatmap.audio_filename
|
| 120 |
+
|
| 121 |
+
audio = load_audio_file(audio_path, sample_rate)
|
| 122 |
+
|
| 123 |
+
parser = OsuParser(model_args, tokenizer)
|
| 124 |
+
events, event_times = parser.parse(beatmap)
|
| 125 |
+
|
| 126 |
+
return create_example(events, event_times, audio, time, model_args, tokenizer, device, unsqueeze)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_mapper_names(path: str):
|
| 130 |
+
path = Path(path)
|
| 131 |
+
|
| 132 |
+
# Load JSON data from file
|
| 133 |
+
with open(path, 'r') as file:
|
| 134 |
+
data = json.load(file)
|
| 135 |
+
|
| 136 |
+
# Populate beatmap_mapper
|
| 137 |
+
mapper_names = {}
|
| 138 |
+
for item in data:
|
| 139 |
+
if len(item['username']) == 0:
|
| 140 |
+
mapper_name = "Unknown"
|
| 141 |
+
else:
|
| 142 |
+
mapper_name = item['username'][0]
|
| 143 |
+
mapper_names[item['user_id']] = mapper_name
|
| 144 |
+
|
| 145 |
+
return mapper_names
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
|
| 149 |
+
def main(args: DictConfig):
|
| 150 |
+
torch.set_grad_enabled(False)
|
| 151 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 152 |
+
|
| 153 |
+
model, model_args, tokenizer = load_ckpt(args.checkpoint_path)
|
| 154 |
+
model.eval().to(device)
|
| 155 |
+
|
| 156 |
+
example = create_example_from_path(args.beatmap_path, args.audio_path, args.time, model_args, tokenizer, device, True)
|
| 157 |
+
result: OsuClassifierOutput = model(**example)
|
| 158 |
+
logits = result.logits
|
| 159 |
+
|
| 160 |
+
# Print the top 100 mappers with confidences
|
| 161 |
+
top_k = 100
|
| 162 |
+
top_k_indices = logits[0].topk(top_k).indices
|
| 163 |
+
top_k_confidences = logits[0].topk(top_k).values
|
| 164 |
+
|
| 165 |
+
mapper_idx_id = {idx: ids for ids, idx in tokenizer.mapper_idx.items()}
|
| 166 |
+
mapper_names = get_mapper_names(args.mappers_path)
|
| 167 |
+
|
| 168 |
+
for idx, confidence in zip(top_k_indices, top_k_confidences):
|
| 169 |
+
mapper_id = mapper_idx_id[idx.item()]
|
| 170 |
+
mapper_name = mapper_names.get(mapper_id, "Unknown")
|
| 171 |
+
print(f"Mapper: {mapper_name} ({mapper_id}) with confidence: {confidence.item()}")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
main()
|
classifier/configs/inference.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compile: true # PyTorch 2.0 optimization
|
| 2 |
+
device: gpu # Training device (cpu/gpu)
|
| 3 |
+
precision: 'no' # Enable mixed precision (no/fp16/bf16/fp8)
|
| 4 |
+
checkpoint_path: 'OliBomby/osu-classifier' # Project checkpoint directory (to resume training)
|
| 5 |
+
beatmap_path: '' # Path to beatmap to classify
|
| 6 |
+
audio_path: '' # Path to audio to classify
|
| 7 |
+
time: 0 # Time to classify
|
| 8 |
+
mappers_path: './/datasets/beatmap_users.json' # Path to mappers dataset
|
| 9 |
+
|
| 10 |
+
hydra:
|
| 11 |
+
job:
|
| 12 |
+
chdir: False
|
| 13 |
+
run:
|
| 14 |
+
dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
classifier/configs/model/model.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
input_features: false
|
| 2 |
+
do_style_embed: true
|
| 3 |
+
classifier_proj_size: 256
|
| 4 |
+
|
| 5 |
+
spectrogram:
|
| 6 |
+
sample_rate: 16000
|
| 7 |
+
hop_length: 128
|
| 8 |
+
n_fft: 1024
|
| 9 |
+
n_mels: 388
|
classifier/configs/model/whisper_base.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'openai/whisper-base'
|
| 6 |
+
input_features: true
|
classifier/configs/model/whisper_base_v2.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'openai/whisper-base'
|
| 6 |
+
input_features: true
|
| 7 |
+
classifier_proj_size: 2048
|
classifier/configs/model/whisper_small.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'openai/whisper-small'
|
| 6 |
+
input_features: true
|
classifier/configs/model/whisper_tiny.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model
|
| 3 |
+
- _self_
|
| 4 |
+
|
| 5 |
+
name: 'openai/whisper-tiny'
|
| 6 |
+
input_features: true
|
classifier/configs/train.yaml
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
compile: true # PyTorch 2.0 optimization
|
| 2 |
+
device: gpu # Training device (cpu/gpu)
|
| 3 |
+
precision: 'bf16-mixed' # Enable mixed precision (no/fp16/bf16/fp8)
|
| 4 |
+
seed: 42 # Project seed
|
| 5 |
+
|
| 6 |
+
checkpoint_path: '' # Project checkpoint directory (to resume training)
|
| 7 |
+
pretrained_path: '' # Path to pretrained model weights (to do transfer learning)
|
| 8 |
+
|
| 9 |
+
data: # Data settings
|
| 10 |
+
train_dataset_path: "/workspace/datasets/ORS16291"
|
| 11 |
+
test_dataset_path: "/workspace/datasets/ORS16291"
|
| 12 |
+
train_dataset_start: 0 # Training dataset start index
|
| 13 |
+
train_dataset_end: 16200 # Training dataset end index
|
| 14 |
+
test_dataset_start: 16200 # Testing/validation dataset start index
|
| 15 |
+
test_dataset_end: 16291 # Testing/validation dataset end index
|
| 16 |
+
src_seq_len: 1024
|
| 17 |
+
tgt_seq_len: 1024
|
| 18 |
+
sample_rate: ${..model.spectrogram.sample_rate}
|
| 19 |
+
hop_length: ${..model.spectrogram.hop_length}
|
| 20 |
+
cycle_length: 16
|
| 21 |
+
per_track: false # Loads all beatmaps in a track sequentially which optimizes audio data loading
|
| 22 |
+
num_classes: 3731 # Number of label classes in the dataset
|
| 23 |
+
timing_random_offset: 0
|
| 24 |
+
min_difficulty: 0 # Minimum difficulty to consider including in the dataset
|
| 25 |
+
mappers_path: "../../../datasets/beatmap_users.json" # Path to file with all beatmap mappers
|
| 26 |
+
add_timing: true # Model beatmap timing
|
| 27 |
+
add_snapping: true # Model hit object snapping
|
| 28 |
+
add_timing_points: false # Model beatmap timing with timing points
|
| 29 |
+
add_hitsounds: true # Model beatmap hitsounds
|
| 30 |
+
add_distances: false # Model hit object distances
|
| 31 |
+
add_positions: true # Model hit object coordinates
|
| 32 |
+
position_precision: 1 # Precision of hit object coordinates
|
| 33 |
+
position_split_axes: true # Split hit object X and Y coordinates into separate tokens
|
| 34 |
+
position_range: [-256, 768, -256, 640] # Range of hit object coordinates
|
| 35 |
+
dt_augment_prob: 0.7 # Probability of augmenting the dataset with DT
|
| 36 |
+
dt_augment_range: [1.25, 1.5] # Range of DT augmentation
|
| 37 |
+
types_first: true # Put the type token at the start of the group before the timeshift token
|
| 38 |
+
augment_flip: false # Augment the dataset with flipped positions
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
dataloader: # Dataloader settings
|
| 42 |
+
num_workers: 8
|
| 43 |
+
|
| 44 |
+
optim: # Optimizer settings
|
| 45 |
+
name: adamw
|
| 46 |
+
base_lr: 1e-2 # Should be scaled with the number of devices present
|
| 47 |
+
batch_size: 128 # This is the batch size per GPU
|
| 48 |
+
total_steps: 65536
|
| 49 |
+
warmup_steps: 10000
|
| 50 |
+
lr_scheduler: cosine
|
| 51 |
+
weight_decay: 0.0
|
| 52 |
+
grad_clip: 1.0
|
| 53 |
+
grad_acc: 2
|
| 54 |
+
final_cosine: 1e-5
|
| 55 |
+
|
| 56 |
+
eval: # Evaluation settings
|
| 57 |
+
every_steps: 1000
|
| 58 |
+
steps: 500
|
| 59 |
+
|
| 60 |
+
checkpoint: # Checkpoint settings
|
| 61 |
+
every_steps: 5000
|
| 62 |
+
|
| 63 |
+
logging: # Logging settings
|
| 64 |
+
log_with: 'wandb' # Logging service (wandb/tensorboard)
|
| 65 |
+
every_steps: 10
|
| 66 |
+
grad_l2: true
|
| 67 |
+
weights_l2: true
|
| 68 |
+
mode: 'online'
|
| 69 |
+
|
| 70 |
+
profile: # Profiling settings
|
| 71 |
+
do_profile: false
|
| 72 |
+
early_stop: false
|
| 73 |
+
wait: 8
|
| 74 |
+
warmup: 8
|
| 75 |
+
active: 8
|
| 76 |
+
repeat: 1
|
| 77 |
+
|
| 78 |
+
hydra:
|
| 79 |
+
job:
|
| 80 |
+
chdir: True
|
| 81 |
+
run:
|
| 82 |
+
dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
classifier/configs/train_v1.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- train
|
| 3 |
+
- _self_
|
| 4 |
+
- model: whisper_tiny
|
classifier/configs/train_v2.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- train
|
| 3 |
+
- _self_
|
| 4 |
+
- model: whisper_base
|
| 5 |
+
|
| 6 |
+
pretrained_path: "../../../test/ckpt_v22"
|
| 7 |
+
|
| 8 |
+
optim: # Optimizer settings
|
| 9 |
+
base_lr: 1e-4 # Should be scaled with the number of devices present
|
| 10 |
+
batch_size: 64 # This is the batch size per GPU
|
| 11 |
+
total_steps: 32218
|
| 12 |
+
warmup_steps: 2000
|
| 13 |
+
grad_acc: 2
|
| 14 |
+
final_cosine: 1e-5
|
classifier/configs/train_v3.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- train
|
| 3 |
+
- _self_
|
| 4 |
+
- model: whisper_base_v2
|
| 5 |
+
|
| 6 |
+
pretrained_path: "../../../test/ckpt_v22"
|
| 7 |
+
|
| 8 |
+
data:
|
| 9 |
+
augment_flip: true
|
| 10 |
+
|
| 11 |
+
optim: # Optimizer settings
|
| 12 |
+
base_lr: 1e-3 # Should be scaled with the number of devices present
|
| 13 |
+
batch_size: 128 # This is the batch size per GPU
|
| 14 |
+
total_steps: 65536
|
| 15 |
+
warmup_steps: 2000
|
| 16 |
+
grad_acc: 4
|
| 17 |
+
final_cosine: 1e-5
|
classifier/count_classes.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def init_mapper_idx(mappers_path):
|
| 6 |
+
""""Indexes beatmap mappers and mapper idx."""
|
| 7 |
+
path = Path(mappers_path)
|
| 8 |
+
|
| 9 |
+
if not path.exists():
|
| 10 |
+
raise ValueError(f"mappers_path {path} not found")
|
| 11 |
+
|
| 12 |
+
# Load JSON data from file
|
| 13 |
+
with open(path, 'r') as file:
|
| 14 |
+
data = json.load(file)
|
| 15 |
+
|
| 16 |
+
# Populate beatmap_mapper
|
| 17 |
+
beatmap_mapper = {}
|
| 18 |
+
for item in data:
|
| 19 |
+
beatmap_mapper[item['id']] = item['user_id']
|
| 20 |
+
|
| 21 |
+
# Get unique user_ids from beatmap_mapper values
|
| 22 |
+
unique_user_ids = list(set(beatmap_mapper.values()))
|
| 23 |
+
|
| 24 |
+
# Create mapper_idx
|
| 25 |
+
mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
|
| 26 |
+
num_mapper_classes = len(unique_user_ids)
|
| 27 |
+
|
| 28 |
+
return beatmap_mapper, mapper_idx, num_mapper_classes
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
path = "../datasets/beatmap_users.json"
|
| 32 |
+
beatmap_mapper, mapper_idx, num_mapper_classes = init_mapper_idx(path)
|
| 33 |
+
|
| 34 |
+
print("Number of mapper classes:", num_mapper_classes)
|
| 35 |
+
print("Number of beatmaps:", len(beatmap_mapper))
|
| 36 |
+
# Calculate number of maps per mapper
|
| 37 |
+
maps_per_mapper = {}
|
| 38 |
+
for beatmap_id in beatmap_mapper:
|
| 39 |
+
user_id = beatmap_mapper[beatmap_id]
|
| 40 |
+
if user_id not in maps_per_mapper:
|
| 41 |
+
maps_per_mapper[user_id] = 0
|
| 42 |
+
maps_per_mapper[user_id] += 1
|
| 43 |
+
|
| 44 |
+
# Calculate average maps per mapper class
|
| 45 |
+
average_maps_per_mapper = len(beatmap_mapper) / num_mapper_classes
|
| 46 |
+
print("Average maps per mapper class:", average_maps_per_mapper)
|
| 47 |
+
|
| 48 |
+
# Calculate median maps per mapper class
|
| 49 |
+
median_maps_per_mapper = sorted(maps_per_mapper.values())[num_mapper_classes // 2]
|
| 50 |
+
print("Median maps per mapper class:", median_maps_per_mapper)
|
| 51 |
+
|
| 52 |
+
# Mapper with most number of maps
|
| 53 |
+
max_maps = max(maps_per_mapper.values())
|
| 54 |
+
max_maps_mapper = [user_id for user_id in maps_per_mapper if maps_per_mapper[user_id] == max_maps]
|
| 55 |
+
print("Mapper with most number of maps:", max_maps_mapper)
|
| 56 |
+
print("Number of maps:", max_maps)
|
classifier/libs/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .utils.model_utils import get_dataloaders, get_optimizer, get_scheduler, get_tokenizer
|
classifier/libs/dataset/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ors_dataset import OrsDataset
|
| 2 |
+
from .osu_parser import OsuParser
|
| 3 |
+
from .data_utils import update_event_times
|
classifier/libs/dataset/data_utils.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pydub import AudioSegment
|
| 7 |
+
|
| 8 |
+
import numpy.typing as npt
|
| 9 |
+
|
| 10 |
+
from ..tokenizer import Event, EventType
|
| 11 |
+
|
| 12 |
+
MILISECONDS_PER_SECOND = 1000
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_audio_file(file: Path, sample_rate: int, speed: float = 1.0) -> npt.NDArray:
|
| 16 |
+
"""Load an audio file as a numpy time-series array
|
| 17 |
+
|
| 18 |
+
The signals are resampled, converted to mono channel, and normalized.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
file: Path to audio file.
|
| 22 |
+
sample_rate: Sample rate to resample the audio.
|
| 23 |
+
speed: Speed multiplier for the audio.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
samples: Audio time series.
|
| 27 |
+
"""
|
| 28 |
+
file = Path(file)
|
| 29 |
+
audio = AudioSegment.from_file(file, format=file.suffix[1:])
|
| 30 |
+
audio.frame_rate = int(audio.frame_rate * speed)
|
| 31 |
+
audio = audio.set_frame_rate(sample_rate)
|
| 32 |
+
audio = audio.set_channels(1)
|
| 33 |
+
samples = np.array(audio.get_array_of_samples()).astype(np.float32)
|
| 34 |
+
samples *= 1.0 / np.max(np.abs(samples))
|
| 35 |
+
return samples
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def update_event_times(
|
| 39 |
+
events: list[Event],
|
| 40 |
+
event_times: list[int],
|
| 41 |
+
end_time: Optional[float] = None,
|
| 42 |
+
types_first: bool = False
|
| 43 |
+
) -> None:
|
| 44 |
+
"""Extends the event times list with the times of the new events if the event list is longer than the event times list.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
events: List of events.
|
| 48 |
+
event_times: List of event times.
|
| 49 |
+
end_time: End time of the events, for interpolation.
|
| 50 |
+
types_first: If True, the type token is at the start of the group before the timeshift token.
|
| 51 |
+
"""
|
| 52 |
+
non_timed_events = [
|
| 53 |
+
EventType.BEZIER_ANCHOR,
|
| 54 |
+
EventType.PERFECT_ANCHOR,
|
| 55 |
+
EventType.CATMULL_ANCHOR,
|
| 56 |
+
EventType.RED_ANCHOR,
|
| 57 |
+
]
|
| 58 |
+
timed_events = [
|
| 59 |
+
EventType.CIRCLE,
|
| 60 |
+
EventType.SPINNER,
|
| 61 |
+
EventType.SPINNER_END,
|
| 62 |
+
EventType.SLIDER_HEAD,
|
| 63 |
+
EventType.LAST_ANCHOR,
|
| 64 |
+
EventType.SLIDER_END,
|
| 65 |
+
EventType.BEAT,
|
| 66 |
+
EventType.MEASURE,
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
start_index = len(event_times)
|
| 70 |
+
end_index = len(events)
|
| 71 |
+
current_time = 0 if len(event_times) == 0 else event_times[-1]
|
| 72 |
+
for i in range(start_index, end_index):
|
| 73 |
+
if types_first:
|
| 74 |
+
if i + 1 < end_index and events[i + 1].type == EventType.TIME_SHIFT:
|
| 75 |
+
current_time = events[i + 1].value
|
| 76 |
+
elif events[i].type == EventType.TIME_SHIFT:
|
| 77 |
+
current_time = events[i].value
|
| 78 |
+
event_times.append(current_time)
|
| 79 |
+
|
| 80 |
+
# Interpolate time for control point events
|
| 81 |
+
interpolate = False
|
| 82 |
+
if types_first:
|
| 83 |
+
# Start-T-D-CP-D-CP-D-LCP-T-D-End-T-D
|
| 84 |
+
# 1-----1-1-1--1-1--1-7---7-7-9---9-9
|
| 85 |
+
# 1-----1-1-3--3-5--5-7---7-7-9---9-9
|
| 86 |
+
index = range(start_index, end_index)
|
| 87 |
+
current_time = 0 if len(event_times) == 0 else event_times[-1]
|
| 88 |
+
else:
|
| 89 |
+
# T-D-Start-D-CP-D-CP-T-D-LCP-T-D-End
|
| 90 |
+
# 1-1-1-----1-1--1-1--7-7--7--9-9-9--
|
| 91 |
+
# 1-1-1-----3-3--5-5--7-7--7--9-9-9--
|
| 92 |
+
index = range(end_index - 1, start_index - 1, -1)
|
| 93 |
+
current_time = end_time if end_time is not None else event_times[-1]
|
| 94 |
+
for i in index:
|
| 95 |
+
event = events[i]
|
| 96 |
+
|
| 97 |
+
if event.type in timed_events:
|
| 98 |
+
interpolate = False
|
| 99 |
+
|
| 100 |
+
if event.type in non_timed_events:
|
| 101 |
+
interpolate = True
|
| 102 |
+
|
| 103 |
+
if not interpolate:
|
| 104 |
+
current_time = event_times[i]
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
if event.type not in non_timed_events:
|
| 108 |
+
event_times[i] = current_time
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# Find the time of the first timed event and the number of control points between
|
| 112 |
+
j = i
|
| 113 |
+
step = 1 if types_first else -1
|
| 114 |
+
count = 0
|
| 115 |
+
other_time = current_time
|
| 116 |
+
while 0 <= j < len(events):
|
| 117 |
+
event2 = events[j]
|
| 118 |
+
if event2.type == EventType.TIME_SHIFT:
|
| 119 |
+
other_time = event_times[j]
|
| 120 |
+
break
|
| 121 |
+
if event2.type in non_timed_events:
|
| 122 |
+
count += 1
|
| 123 |
+
j += step
|
| 124 |
+
if j < 0:
|
| 125 |
+
other_time = 0
|
| 126 |
+
if j >= len(events):
|
| 127 |
+
other_time = end_time if end_time is not None else event_times[-1]
|
| 128 |
+
|
| 129 |
+
# Interpolate the time
|
| 130 |
+
current_time = int((current_time - other_time) / (count + 1) * count + other_time)
|
| 131 |
+
event_times[i] = current_time
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def merge_events(events1: list[Event], event_times1: list[int], events2: list[Event], event_times2: list[int]) -> tuple[list[Event], list[int]]:
|
| 135 |
+
"""Merge two lists of events in a time sorted manner. Assumes both lists are sorted by time.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
events1: List of events.
|
| 139 |
+
event_times1: List of event times.
|
| 140 |
+
events2: List of events.
|
| 141 |
+
event_times2: List of event times.
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
merged_events: Merged list of events.
|
| 145 |
+
merged_event_times: Merged list of event times.
|
| 146 |
+
"""
|
| 147 |
+
merged_events = []
|
| 148 |
+
merged_event_times = []
|
| 149 |
+
i = 0
|
| 150 |
+
j = 0
|
| 151 |
+
|
| 152 |
+
while i < len(events1) and j < len(events2):
|
| 153 |
+
t1 = event_times1[i]
|
| 154 |
+
t2 = event_times2[j]
|
| 155 |
+
|
| 156 |
+
if t1 <= t2:
|
| 157 |
+
merged_events.append(events1[i])
|
| 158 |
+
merged_event_times.append(t1)
|
| 159 |
+
i += 1
|
| 160 |
+
else:
|
| 161 |
+
merged_events.append(events2[j])
|
| 162 |
+
merged_event_times.append(t2)
|
| 163 |
+
j += 1
|
| 164 |
+
|
| 165 |
+
merged_events.extend(events1[i:])
|
| 166 |
+
merged_events.extend(events2[j:])
|
| 167 |
+
merged_event_times.extend(event_times1[i:])
|
| 168 |
+
merged_event_times.extend(event_times2[j:])
|
| 169 |
+
return merged_events, merged_event_times
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def remove_events_of_type(events: list[Event], event_times: list[int], event_types: list[EventType]) -> tuple[list[Event], list[int]]:
|
| 173 |
+
"""Remove all events of a specific type from a list of events.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
events: List of events.
|
| 177 |
+
event_types: Types of event to remove.
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
filtered_events: Filtered list of events.
|
| 181 |
+
"""
|
| 182 |
+
new_events = []
|
| 183 |
+
new_event_times = []
|
| 184 |
+
for event, time in zip(events, event_times):
|
| 185 |
+
if event.type not in event_types:
|
| 186 |
+
new_events.append(event)
|
| 187 |
+
new_event_times.append(time)
|
| 188 |
+
return new_events, new_event_times
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def speed_events(events: list[Event], event_times: list[int], speed: float) -> tuple[list[Event], list[int]]:
|
| 192 |
+
"""Change the speed of a list of events.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
events: List of events.
|
| 196 |
+
event_times: List of event times
|
| 197 |
+
speed: Speed multiplier.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
sped_events: Sped up list of events.
|
| 201 |
+
"""
|
| 202 |
+
sped_events = []
|
| 203 |
+
for event in events:
|
| 204 |
+
if event.type == EventType.TIME_SHIFT:
|
| 205 |
+
event.value = int(event.value / speed)
|
| 206 |
+
sped_events.append(event)
|
| 207 |
+
|
| 208 |
+
sped_event_times = []
|
| 209 |
+
for t in event_times:
|
| 210 |
+
sped_event_times.append(int(t / speed))
|
| 211 |
+
|
| 212 |
+
return sped_events, sped_event_times
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@dataclasses.dataclass
|
| 216 |
+
class Group:
|
| 217 |
+
event_type: EventType = None
|
| 218 |
+
time: int = 0
|
| 219 |
+
distance: int = None
|
| 220 |
+
x: float = None
|
| 221 |
+
y: float = None
|
| 222 |
+
new_combo: bool = False
|
| 223 |
+
hitsounds: list[int] = dataclasses.field(default_factory=list)
|
| 224 |
+
samplesets: list[int] = dataclasses.field(default_factory=list)
|
| 225 |
+
additions: list[int] = dataclasses.field(default_factory=list)
|
| 226 |
+
volumes: list[int] = dataclasses.field(default_factory=list)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
type_events = [
|
| 230 |
+
EventType.CIRCLE,
|
| 231 |
+
EventType.SPINNER,
|
| 232 |
+
EventType.SPINNER_END,
|
| 233 |
+
EventType.SLIDER_HEAD,
|
| 234 |
+
EventType.BEZIER_ANCHOR,
|
| 235 |
+
EventType.PERFECT_ANCHOR,
|
| 236 |
+
EventType.CATMULL_ANCHOR,
|
| 237 |
+
EventType.RED_ANCHOR,
|
| 238 |
+
EventType.LAST_ANCHOR,
|
| 239 |
+
EventType.SLIDER_END,
|
| 240 |
+
EventType.BEAT,
|
| 241 |
+
EventType.MEASURE,
|
| 242 |
+
]
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def get_groups(
|
| 246 |
+
events: list[Event],
|
| 247 |
+
*,
|
| 248 |
+
event_times: Optional[list[int]] = None,
|
| 249 |
+
types_first: bool = False
|
| 250 |
+
) -> list[Group]:
|
| 251 |
+
groups = []
|
| 252 |
+
group = Group()
|
| 253 |
+
for i, event in enumerate(events):
|
| 254 |
+
if event.type == EventType.TIME_SHIFT:
|
| 255 |
+
group.time = event.value
|
| 256 |
+
elif event.type == EventType.DISTANCE:
|
| 257 |
+
group.distance = event.value
|
| 258 |
+
elif event.type == EventType.POS_X:
|
| 259 |
+
group.x = event.value
|
| 260 |
+
elif event.type == EventType.POS_Y:
|
| 261 |
+
group.y = event.value
|
| 262 |
+
elif event.type == EventType.NEW_COMBO:
|
| 263 |
+
group.new_combo = True
|
| 264 |
+
elif event.type == EventType.HITSOUND:
|
| 265 |
+
group.hitsounds.append((event.value % 8) * 2)
|
| 266 |
+
group.samplesets.append(((event.value // 8) % 3) + 1)
|
| 267 |
+
group.additions.append(((event.value // 24) % 3) + 1)
|
| 268 |
+
elif event.type == EventType.VOLUME:
|
| 269 |
+
group.volumes.append(event.value)
|
| 270 |
+
elif event.type in type_events:
|
| 271 |
+
if types_first:
|
| 272 |
+
if group.event_type is not None:
|
| 273 |
+
groups.append(group)
|
| 274 |
+
group = Group()
|
| 275 |
+
group.event_type = event.type
|
| 276 |
+
if event_times is not None:
|
| 277 |
+
group.time = event_times[i]
|
| 278 |
+
else:
|
| 279 |
+
group.event_type = event.type
|
| 280 |
+
if event_times is not None:
|
| 281 |
+
group.time = event_times[i]
|
| 282 |
+
groups.append(group)
|
| 283 |
+
group = Group()
|
| 284 |
+
|
| 285 |
+
if group.event_type is not None:
|
| 286 |
+
groups.append(group)
|
| 287 |
+
|
| 288 |
+
return groups
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def get_group_indices(events: list[Event], types_first: bool = False) -> list[list[int]]:
|
| 292 |
+
groups = []
|
| 293 |
+
indices = []
|
| 294 |
+
for i, event in enumerate(events):
|
| 295 |
+
indices.append(i)
|
| 296 |
+
if event.type in type_events:
|
| 297 |
+
if types_first:
|
| 298 |
+
if len(indices) > 1:
|
| 299 |
+
groups.append(indices[:-1])
|
| 300 |
+
indices = [indices[-1]]
|
| 301 |
+
else:
|
| 302 |
+
groups.append(indices)
|
| 303 |
+
indices = []
|
| 304 |
+
|
| 305 |
+
if len(indices) > 0:
|
| 306 |
+
groups.append(indices)
|
| 307 |
+
|
| 308 |
+
return groups
|
classifier/libs/dataset/ors_dataset.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from typing import Optional, Callable
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import numpy.typing as npt
|
| 11 |
+
import torch
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
from slider import Beatmap
|
| 14 |
+
from torch.utils.data import IterableDataset
|
| 15 |
+
|
| 16 |
+
from .data_utils import load_audio_file
|
| 17 |
+
from .osu_parser import OsuParser
|
| 18 |
+
from ..tokenizer import Event, EventType, Tokenizer
|
| 19 |
+
|
| 20 |
+
OSZ_FILE_EXTENSION = ".osz"
|
| 21 |
+
AUDIO_FILE_NAME = "audio.mp3"
|
| 22 |
+
MILISECONDS_PER_SECOND = 1000
|
| 23 |
+
STEPS_PER_MILLISECOND = 0.1
|
| 24 |
+
LABEL_IGNORE_ID = -100
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class OrsDataset(IterableDataset):
|
| 28 |
+
__slots__ = (
|
| 29 |
+
"path",
|
| 30 |
+
"start",
|
| 31 |
+
"end",
|
| 32 |
+
"args",
|
| 33 |
+
"parser",
|
| 34 |
+
"tokenizer",
|
| 35 |
+
"beatmap_files",
|
| 36 |
+
"test",
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
args: DictConfig,
|
| 42 |
+
parser: OsuParser,
|
| 43 |
+
tokenizer: Tokenizer,
|
| 44 |
+
beatmap_files: Optional[list[Path]] = None,
|
| 45 |
+
test: bool = False,
|
| 46 |
+
):
|
| 47 |
+
"""Manage and process ORS dataset.
|
| 48 |
+
|
| 49 |
+
Attributes:
|
| 50 |
+
args: Data loading arguments.
|
| 51 |
+
parser: Instance of OsuParser class.
|
| 52 |
+
tokenizer: Instance of Tokenizer class.
|
| 53 |
+
beatmap_files: List of beatmap files to process. Overrides track index range.
|
| 54 |
+
test: Whether to load the test dataset.
|
| 55 |
+
"""
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.path = args.test_dataset_path if test else args.train_dataset_path
|
| 58 |
+
self.start = args.test_dataset_start if test else args.train_dataset_start
|
| 59 |
+
self.end = args.test_dataset_end if test else args.train_dataset_end
|
| 60 |
+
self.args = args
|
| 61 |
+
self.parser = parser
|
| 62 |
+
self.tokenizer = tokenizer
|
| 63 |
+
self.beatmap_files = beatmap_files
|
| 64 |
+
self.test = test
|
| 65 |
+
|
| 66 |
+
def _get_beatmap_files(self) -> list[Path]:
|
| 67 |
+
if self.beatmap_files is not None:
|
| 68 |
+
return self.beatmap_files
|
| 69 |
+
|
| 70 |
+
# Get a list of all beatmap files in the dataset path in the track index range between start and end
|
| 71 |
+
beatmap_files = []
|
| 72 |
+
track_names = ["Track" + str(i).zfill(5) for i in range(self.start, self.end)]
|
| 73 |
+
for track_name in track_names:
|
| 74 |
+
for beatmap_file in os.listdir(
|
| 75 |
+
os.path.join(self.path, track_name, "beatmaps"),
|
| 76 |
+
):
|
| 77 |
+
beatmap_files.append(
|
| 78 |
+
Path(
|
| 79 |
+
os.path.join(
|
| 80 |
+
self.path,
|
| 81 |
+
track_name,
|
| 82 |
+
"beatmaps",
|
| 83 |
+
beatmap_file,
|
| 84 |
+
)
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return beatmap_files
|
| 89 |
+
|
| 90 |
+
def _get_track_paths(self) -> list[Path]:
|
| 91 |
+
track_paths = []
|
| 92 |
+
track_names = ["Track" + str(i).zfill(5) for i in range(self.start, self.end)]
|
| 93 |
+
for track_name in track_names:
|
| 94 |
+
track_paths.append(Path(os.path.join(self.path, track_name)))
|
| 95 |
+
return track_paths
|
| 96 |
+
|
| 97 |
+
def __iter__(self):
|
| 98 |
+
beatmap_files = self._get_track_paths() if self.args.per_track else self._get_beatmap_files()
|
| 99 |
+
|
| 100 |
+
if not self.test:
|
| 101 |
+
random.shuffle(beatmap_files)
|
| 102 |
+
|
| 103 |
+
if self.args.cycle_length > 1 and not self.test:
|
| 104 |
+
return InterleavingBeatmapDatasetIterable(
|
| 105 |
+
beatmap_files,
|
| 106 |
+
self._iterable_factory,
|
| 107 |
+
self.args.cycle_length,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return self._iterable_factory(beatmap_files).__iter__()
|
| 111 |
+
|
| 112 |
+
def _iterable_factory(self, beatmap_files: list[Path]):
|
| 113 |
+
return BeatmapDatasetIterable(
|
| 114 |
+
beatmap_files,
|
| 115 |
+
self.args,
|
| 116 |
+
self.parser,
|
| 117 |
+
self.tokenizer,
|
| 118 |
+
self.test,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class InterleavingBeatmapDatasetIterable:
|
| 123 |
+
__slots__ = ("workers", "cycle_length", "index")
|
| 124 |
+
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
beatmap_files: list[Path],
|
| 128 |
+
iterable_factory: Callable,
|
| 129 |
+
cycle_length: int,
|
| 130 |
+
):
|
| 131 |
+
per_worker = int(np.ceil(len(beatmap_files) / float(cycle_length)))
|
| 132 |
+
self.workers = [
|
| 133 |
+
iterable_factory(
|
| 134 |
+
beatmap_files[
|
| 135 |
+
i * per_worker: min(len(beatmap_files), (i + 1) * per_worker)
|
| 136 |
+
]
|
| 137 |
+
).__iter__()
|
| 138 |
+
for i in range(cycle_length)
|
| 139 |
+
]
|
| 140 |
+
self.cycle_length = cycle_length
|
| 141 |
+
self.index = 0
|
| 142 |
+
|
| 143 |
+
def __iter__(self) -> "InterleavingBeatmapDatasetIterable":
|
| 144 |
+
return self
|
| 145 |
+
|
| 146 |
+
def __next__(self) -> tuple[any, int]:
|
| 147 |
+
num = len(self.workers)
|
| 148 |
+
for _ in range(num):
|
| 149 |
+
try:
|
| 150 |
+
self.index = self.index % len(self.workers)
|
| 151 |
+
item = self.workers[self.index].__next__()
|
| 152 |
+
self.index += 1
|
| 153 |
+
return item
|
| 154 |
+
except StopIteration:
|
| 155 |
+
self.workers.remove(self.workers[self.index])
|
| 156 |
+
raise StopIteration
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class BeatmapDatasetIterable:
|
| 160 |
+
__slots__ = (
|
| 161 |
+
"beatmap_files",
|
| 162 |
+
"args",
|
| 163 |
+
"parser",
|
| 164 |
+
"tokenizer",
|
| 165 |
+
"test",
|
| 166 |
+
"frame_seq_len",
|
| 167 |
+
"pre_token_len",
|
| 168 |
+
"add_empty_sequences",
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
beatmap_files: list[Path],
|
| 174 |
+
args: DictConfig,
|
| 175 |
+
parser: OsuParser,
|
| 176 |
+
tokenizer: Tokenizer,
|
| 177 |
+
test: bool,
|
| 178 |
+
):
|
| 179 |
+
self.beatmap_files = beatmap_files
|
| 180 |
+
self.args = args
|
| 181 |
+
self.parser = parser
|
| 182 |
+
self.tokenizer = tokenizer
|
| 183 |
+
self.test = test
|
| 184 |
+
self.frame_seq_len = args.src_seq_len - 1
|
| 185 |
+
|
| 186 |
+
def _get_frames(self, samples: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
|
| 187 |
+
"""Segment audio samples into frames.
|
| 188 |
+
|
| 189 |
+
Each frame has `frame_size` audio samples.
|
| 190 |
+
It will also calculate and return the time of each audio frame, in miliseconds.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
samples: Audio time-series.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
frames: Audio frames.
|
| 197 |
+
frame_times: Audio frame times.
|
| 198 |
+
"""
|
| 199 |
+
samples = np.pad(samples, [0, self.args.hop_length - len(samples) % self.args.hop_length])
|
| 200 |
+
frames = np.reshape(samples, (-1, self.args.hop_length))
|
| 201 |
+
frames_per_milisecond = (
|
| 202 |
+
self.args.sample_rate / self.args.hop_length / MILISECONDS_PER_SECOND
|
| 203 |
+
)
|
| 204 |
+
frame_times = np.arange(len(frames)) / frames_per_milisecond
|
| 205 |
+
return frames, frame_times
|
| 206 |
+
|
| 207 |
+
def _create_sequences(
|
| 208 |
+
self,
|
| 209 |
+
frames: npt.NDArray,
|
| 210 |
+
frame_times: npt.NDArray,
|
| 211 |
+
context: dict,
|
| 212 |
+
extra_data: Optional[dict] = None,
|
| 213 |
+
) -> list[dict[str, int | npt.NDArray | list[Event]]]:
|
| 214 |
+
"""Create frame and token sequences for training/testing.
|
| 215 |
+
|
| 216 |
+
Args:
|
| 217 |
+
frames: Audio frames.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
A list of source and target sequences.
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
def get_event_indices(events2: list[Event], event_times2: list[int]) -> tuple[list[int], list[int]]:
|
| 224 |
+
if len(events2) == 0:
|
| 225 |
+
return [], []
|
| 226 |
+
|
| 227 |
+
# Corresponding start event index for every audio frame.
|
| 228 |
+
start_indices = []
|
| 229 |
+
event_index = 0
|
| 230 |
+
|
| 231 |
+
for current_time in frame_times:
|
| 232 |
+
while event_index < len(events2) and event_times2[event_index] < current_time:
|
| 233 |
+
event_index += 1
|
| 234 |
+
start_indices.append(event_index)
|
| 235 |
+
|
| 236 |
+
# Corresponding end event index for every audio frame.
|
| 237 |
+
end_indices = start_indices[1:] + [len(events2)]
|
| 238 |
+
|
| 239 |
+
return start_indices, end_indices
|
| 240 |
+
|
| 241 |
+
start_indices, end_indices = get_event_indices(context["events"], context["event_times"])
|
| 242 |
+
|
| 243 |
+
sequences = []
|
| 244 |
+
n_frames = len(frames)
|
| 245 |
+
offset = random.randint(0, self.frame_seq_len)
|
| 246 |
+
# Divide audio frames into splits
|
| 247 |
+
for frame_start_idx in range(offset, n_frames, self.frame_seq_len):
|
| 248 |
+
frame_end_idx = min(frame_start_idx + self.frame_seq_len, n_frames)
|
| 249 |
+
|
| 250 |
+
def slice_events(context, frame_start_idx, frame_end_idx):
|
| 251 |
+
if len(context["events"]) == 0:
|
| 252 |
+
return []
|
| 253 |
+
event_start_idx = start_indices[frame_start_idx]
|
| 254 |
+
event_end_idx = end_indices[frame_end_idx - 1]
|
| 255 |
+
return context["events"][event_start_idx:event_end_idx]
|
| 256 |
+
|
| 257 |
+
def slice_context(context, frame_start_idx, frame_end_idx):
|
| 258 |
+
return {"events": slice_events(context, frame_start_idx, frame_end_idx)}
|
| 259 |
+
|
| 260 |
+
# Create the sequence
|
| 261 |
+
sequence = {
|
| 262 |
+
"time": frame_times[frame_start_idx],
|
| 263 |
+
"frames": frames[frame_start_idx:frame_end_idx],
|
| 264 |
+
"context": slice_context(context, frame_start_idx, frame_end_idx),
|
| 265 |
+
} | extra_data
|
| 266 |
+
|
| 267 |
+
sequences.append(sequence)
|
| 268 |
+
|
| 269 |
+
return sequences
|
| 270 |
+
|
| 271 |
+
def _normalize_time_shifts(self, sequence: dict) -> dict:
|
| 272 |
+
"""Make all time shifts in the sequence relative to the start time of the sequence,
|
| 273 |
+
and normalize time values.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
sequence: The input sequence.
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
The same sequence with trimmed time shifts.
|
| 280 |
+
"""
|
| 281 |
+
|
| 282 |
+
def process(events: list[Event], start_time) -> list[Event] | tuple[list[Event], int]:
|
| 283 |
+
for i, event in enumerate(events):
|
| 284 |
+
if event.type == EventType.TIME_SHIFT:
|
| 285 |
+
# We cant modify the event objects themselves because that will affect subsequent sequences
|
| 286 |
+
events[i] = Event(EventType.TIME_SHIFT, int((event.value - start_time) * STEPS_PER_MILLISECOND))
|
| 287 |
+
|
| 288 |
+
return events
|
| 289 |
+
|
| 290 |
+
start_time = sequence["time"]
|
| 291 |
+
del sequence["time"]
|
| 292 |
+
|
| 293 |
+
sequence["context"]["events"] = process(sequence["context"]["events"], start_time)
|
| 294 |
+
|
| 295 |
+
return sequence
|
| 296 |
+
|
| 297 |
+
def _tokenize_sequence(self, sequence: dict) -> dict:
|
| 298 |
+
"""Tokenize the event sequence.
|
| 299 |
+
|
| 300 |
+
Begin token sequence with `[SOS]` token (start-of-sequence).
|
| 301 |
+
End token sequence with `[EOS]` token (end-of-sequence).
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
sequence: The input sequence.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
The same sequence with tokenized events.
|
| 308 |
+
"""
|
| 309 |
+
context = sequence["context"]
|
| 310 |
+
tokens = torch.empty(len(context["events"]), dtype=torch.long)
|
| 311 |
+
for i, event in enumerate(context["events"]):
|
| 312 |
+
tokens[i] = self.tokenizer.encode(event)
|
| 313 |
+
context["tokens"] = tokens
|
| 314 |
+
|
| 315 |
+
return sequence
|
| 316 |
+
|
| 317 |
+
def _pad_and_split_token_sequence(self, sequence: dict) -> dict:
|
| 318 |
+
"""Pad token sequence to a fixed length and split decoder input and labels.
|
| 319 |
+
|
| 320 |
+
Pad with `[PAD]` tokens until `tgt_seq_len`.
|
| 321 |
+
|
| 322 |
+
Token sequence (w/o last token) is the input to the transformer decoder,
|
| 323 |
+
token sequence (w/o first token) is the label, a.k.a. decoder ground truth.
|
| 324 |
+
|
| 325 |
+
Prefix the token sequence with the pre_tokens sequence.
|
| 326 |
+
|
| 327 |
+
Args:
|
| 328 |
+
sequence: The input sequence.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
The same sequence with padded tokens.
|
| 332 |
+
"""
|
| 333 |
+
# Count reducible tokens, pre_tokens and context tokens
|
| 334 |
+
num_tokens = len(sequence["context"]["tokens"])
|
| 335 |
+
|
| 336 |
+
# Trim tokens to target sequence length
|
| 337 |
+
# n + padding = tgt_seq_len
|
| 338 |
+
n = min(self.args.tgt_seq_len, num_tokens)
|
| 339 |
+
si = 0
|
| 340 |
+
|
| 341 |
+
input_tokens = torch.full((self.args.tgt_seq_len,), self.tokenizer.pad_id, dtype=torch.long)
|
| 342 |
+
|
| 343 |
+
tokens = sequence["context"]["tokens"]
|
| 344 |
+
|
| 345 |
+
input_tokens[si:si + n] = tokens[:n]
|
| 346 |
+
|
| 347 |
+
# Randomize some input tokens
|
| 348 |
+
def randomize_tokens(tokens):
|
| 349 |
+
offset = torch.randint(low=-self.args.timing_random_offset, high=self.args.timing_random_offset + 1,
|
| 350 |
+
size=tokens.shape)
|
| 351 |
+
return torch.where((self.tokenizer.event_start[EventType.TIME_SHIFT] <= tokens) & (
|
| 352 |
+
tokens < self.tokenizer.event_end[EventType.TIME_SHIFT]),
|
| 353 |
+
torch.clamp(tokens + offset,
|
| 354 |
+
self.tokenizer.event_start[EventType.TIME_SHIFT],
|
| 355 |
+
self.tokenizer.event_end[EventType.TIME_SHIFT] - 1),
|
| 356 |
+
tokens)
|
| 357 |
+
|
| 358 |
+
if self.args.timing_random_offset > 0:
|
| 359 |
+
input_tokens[si:si + n] = randomize_tokens(input_tokens[si:si + n])
|
| 360 |
+
|
| 361 |
+
sequence["decoder_input_ids"] = input_tokens
|
| 362 |
+
sequence["decoder_attention_mask"] = input_tokens != self.tokenizer.pad_id
|
| 363 |
+
|
| 364 |
+
del sequence["context"]
|
| 365 |
+
|
| 366 |
+
return sequence
|
| 367 |
+
|
| 368 |
+
def _pad_frame_sequence(self, sequence: dict) -> dict:
|
| 369 |
+
"""Pad frame sequence with zeros until `frame_seq_len`.
|
| 370 |
+
|
| 371 |
+
Frame sequence can be further processed into Mel spectrogram frames,
|
| 372 |
+
which is the input to the transformer encoder.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
sequence: The input sequence.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
The same sequence with padded frames.
|
| 379 |
+
"""
|
| 380 |
+
frames = torch.from_numpy(sequence["frames"]).to(torch.float32)
|
| 381 |
+
|
| 382 |
+
if frames.shape[0] != self.frame_seq_len:
|
| 383 |
+
n = min(self.frame_seq_len, len(frames))
|
| 384 |
+
padded_frames = torch.zeros(
|
| 385 |
+
self.frame_seq_len,
|
| 386 |
+
frames.shape[-1],
|
| 387 |
+
dtype=frames.dtype,
|
| 388 |
+
device=frames.device,
|
| 389 |
+
)
|
| 390 |
+
padded_frames[:n] = frames[:n]
|
| 391 |
+
sequence["frames"] = torch.flatten(padded_frames)
|
| 392 |
+
else:
|
| 393 |
+
sequence["frames"] = torch.flatten(frames)
|
| 394 |
+
|
| 395 |
+
return sequence
|
| 396 |
+
|
| 397 |
+
def __iter__(self):
|
| 398 |
+
return self._get_next_tracks() if self.args.per_track else self._get_next_beatmaps()
|
| 399 |
+
|
| 400 |
+
@staticmethod
|
| 401 |
+
def _load_metadata(track_path: Path) -> dict:
|
| 402 |
+
metadata_file = track_path / "metadata.json"
|
| 403 |
+
with open(metadata_file) as f:
|
| 404 |
+
return json.load(f)
|
| 405 |
+
|
| 406 |
+
def _get_difficulty(self, metadata: dict, beatmap_name: str, speed: float = 1.0, beatmap: Beatmap = None) -> float:
|
| 407 |
+
if beatmap is not None and (all(e == 1.5 for e in self.args.dt_augment_range) or speed not in [1.0, 1.5]):
|
| 408 |
+
return beatmap.stars(speed_scale=speed)
|
| 409 |
+
|
| 410 |
+
if speed == 1.5:
|
| 411 |
+
return metadata["Beatmaps"][beatmap_name]["StandardStarRating"]["64"]
|
| 412 |
+
return metadata["Beatmaps"][beatmap_name]["StandardStarRating"]["0"]
|
| 413 |
+
|
| 414 |
+
@staticmethod
|
| 415 |
+
def _get_idx(metadata: dict, beatmap_name: str):
|
| 416 |
+
return metadata["Beatmaps"][beatmap_name]["Index"]
|
| 417 |
+
|
| 418 |
+
def _get_speed_augment(self):
|
| 419 |
+
mi, ma = self.args.dt_augment_range
|
| 420 |
+
return random.random() * (ma - mi) + mi if random.random() < self.args.dt_augment_prob else 1.0
|
| 421 |
+
|
| 422 |
+
def _get_next_beatmaps(self) -> dict:
|
| 423 |
+
for beatmap_path in self.beatmap_files:
|
| 424 |
+
metadata = self._load_metadata(beatmap_path.parents[1])
|
| 425 |
+
|
| 426 |
+
if self.args.min_difficulty > 0 and self._get_difficulty(metadata,
|
| 427 |
+
beatmap_path.stem) < self.args.min_difficulty:
|
| 428 |
+
continue
|
| 429 |
+
|
| 430 |
+
speed = self._get_speed_augment()
|
| 431 |
+
audio_path = beatmap_path.parents[1] / list(beatmap_path.parents[1].glob('audio.*'))[0]
|
| 432 |
+
audio_samples = load_audio_file(audio_path, self.args.sample_rate, speed)
|
| 433 |
+
|
| 434 |
+
for sample in self._get_next_beatmap(audio_samples, beatmap_path, speed):
|
| 435 |
+
yield sample
|
| 436 |
+
|
| 437 |
+
def _get_next_tracks(self) -> dict:
|
| 438 |
+
for track_path in self.beatmap_files:
|
| 439 |
+
metadata = self._load_metadata(track_path)
|
| 440 |
+
|
| 441 |
+
if self.args.min_difficulty > 0 and all(self._get_difficulty(metadata, beatmap_name)
|
| 442 |
+
< self.args.min_difficulty for beatmap_name in
|
| 443 |
+
metadata["Beatmaps"]):
|
| 444 |
+
continue
|
| 445 |
+
|
| 446 |
+
speed = self._get_speed_augment()
|
| 447 |
+
audio_path = track_path / list(track_path.glob('audio.*'))[0]
|
| 448 |
+
audio_samples = load_audio_file(audio_path, self.args.sample_rate, speed)
|
| 449 |
+
|
| 450 |
+
for beatmap_name in metadata["Beatmaps"]:
|
| 451 |
+
beatmap_path = (track_path / "beatmaps" / beatmap_name).with_suffix(".osu")
|
| 452 |
+
|
| 453 |
+
if self.args.min_difficulty > 0 and self._get_difficulty(metadata,
|
| 454 |
+
beatmap_name) < self.args.min_difficulty:
|
| 455 |
+
continue
|
| 456 |
+
|
| 457 |
+
for sample in self._get_next_beatmap(audio_samples, beatmap_path, speed):
|
| 458 |
+
yield sample
|
| 459 |
+
|
| 460 |
+
def _get_next_beatmap(self, audio_samples, beatmap_path: Path, speed: float) -> dict:
|
| 461 |
+
frames, frame_times = self._get_frames(audio_samples)
|
| 462 |
+
osu_beatmap = Beatmap.from_path(beatmap_path)
|
| 463 |
+
|
| 464 |
+
if osu_beatmap.beatmap_id not in self.tokenizer.beatmap_mapper:
|
| 465 |
+
return
|
| 466 |
+
|
| 467 |
+
extra_data = {
|
| 468 |
+
"labels": self.tokenizer.mapper_idx[self.tokenizer.beatmap_mapper[osu_beatmap.beatmap_id]],
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
flip_x, flip_y = False, False
|
| 472 |
+
if self.args.augment_flip:
|
| 473 |
+
flip_x, flip_y = random.random() < 0.5, random.random() < 0.5
|
| 474 |
+
|
| 475 |
+
events, event_times = self.parser.parse(osu_beatmap, speed, flip_x, flip_y)
|
| 476 |
+
in_context = {"events": events, "event_times": event_times}
|
| 477 |
+
|
| 478 |
+
sequences = self._create_sequences(
|
| 479 |
+
frames,
|
| 480 |
+
frame_times,
|
| 481 |
+
in_context,
|
| 482 |
+
extra_data,
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
for sequence in sequences:
|
| 486 |
+
sequence = self._normalize_time_shifts(sequence)
|
| 487 |
+
sequence = self._tokenize_sequence(sequence)
|
| 488 |
+
sequence = self._pad_frame_sequence(sequence)
|
| 489 |
+
sequence = self._pad_and_split_token_sequence(sequence)
|
| 490 |
+
yield sequence
|
classifier/libs/dataset/osu_parser.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from datetime import timedelta
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
from slider import Beatmap, Circle, Slider, Spinner
|
| 10 |
+
from slider.curve import Linear, Catmull, Perfect, MultiBezier
|
| 11 |
+
|
| 12 |
+
from ..tokenizer import Event, EventType, Tokenizer
|
| 13 |
+
from .data_utils import merge_events, speed_events
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class OsuParser:
|
| 17 |
+
def __init__(self, args: DictConfig, tokenizer: Tokenizer) -> None:
|
| 18 |
+
self.types_first = args.data.types_first
|
| 19 |
+
self.add_timing = args.data.add_timing
|
| 20 |
+
self.add_snapping = args.data.add_snapping
|
| 21 |
+
self.add_timing_points = args.data.add_timing_points
|
| 22 |
+
self.add_hitsounds = args.data.add_hitsounds
|
| 23 |
+
self.add_distances = args.data.add_distances
|
| 24 |
+
self.add_positions = args.data.add_positions
|
| 25 |
+
if self.add_positions:
|
| 26 |
+
self.position_precision = args.data.position_precision
|
| 27 |
+
self.position_split_axes = args.data.position_split_axes
|
| 28 |
+
x_min, x_max, y_min, y_max = args.data.position_range
|
| 29 |
+
self.x_min = x_min / self.position_precision
|
| 30 |
+
self.x_max = x_max / self.position_precision
|
| 31 |
+
self.y_min = y_min / self.position_precision
|
| 32 |
+
self.y_max = y_max / self.position_precision
|
| 33 |
+
self.x_count = self.x_max - self.x_min + 1
|
| 34 |
+
if self.add_distances:
|
| 35 |
+
dist_range = tokenizer.event_range[EventType.DISTANCE]
|
| 36 |
+
self.dist_min = dist_range.min_value
|
| 37 |
+
self.dist_max = dist_range.max_value
|
| 38 |
+
|
| 39 |
+
def parse(
|
| 40 |
+
self,
|
| 41 |
+
beatmap: Beatmap,
|
| 42 |
+
speed: float = 1.0,
|
| 43 |
+
flip_x: bool = False,
|
| 44 |
+
flip_y: bool = False
|
| 45 |
+
) -> tuple[list[Event], list[int]]:
|
| 46 |
+
# noinspection PyUnresolvedReferences
|
| 47 |
+
"""Parse an .osu beatmap.
|
| 48 |
+
|
| 49 |
+
Each hit object is parsed into a list of Event objects, in order of its
|
| 50 |
+
appearance in the beatmap. In other words, in ascending order of time.
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
beatmap: Beatmap object parsed from an .osu file.
|
| 54 |
+
speed: Speed multiplier for the beatmap.
|
| 55 |
+
flip_x: Whether to flip the x-axis.
|
| 56 |
+
flip_y: Whether to flip the y-axis.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
events: List of Event object lists.
|
| 60 |
+
event_times: List of event times.
|
| 61 |
+
|
| 62 |
+
Example::
|
| 63 |
+
>>> beatmap = [
|
| 64 |
+
"64,80,11000,1,0",
|
| 65 |
+
"100,100,16000,2,0,B|200:200|250:200|250:200|300:150,2"
|
| 66 |
+
]
|
| 67 |
+
>>> events = parse(beatmap)
|
| 68 |
+
>>> print(events)
|
| 69 |
+
[
|
| 70 |
+
Event(EventType.TIME_SHIFT, 11000), Event(EventType.DISTANCE, 36), Event(EventType.CIRCLE),
|
| 71 |
+
Event(EventType.TIME_SHIFT, 16000), Event(EventType.DISTANCE, 42), Event(EventType.SLIDER_HEAD),
|
| 72 |
+
Event(EventType.TIME_SHIFT, 16500), Event(EventType.DISTANCE, 141), Event(EventType.BEZIER_ANCHOR),
|
| 73 |
+
Event(EventType.TIME_SHIFT, 17000), Event(EventType.DISTANCE, 50), Event(EventType.BEZIER_ANCHOR),
|
| 74 |
+
Event(EventType.TIME_SHIFT, 17500), Event(EventType.DISTANCE, 10), Event(EventType.BEZIER_ANCHOR),
|
| 75 |
+
Event(EventType.TIME_SHIFT, 18000), Event(EventType.DISTANCE, 64), Event(EventType.LAST _ANCHOR),
|
| 76 |
+
Event(EventType.TIME_SHIFT, 20000), Event(EventType.DISTANCE, 11), Event(EventType.SLIDER_END)
|
| 77 |
+
]
|
| 78 |
+
"""
|
| 79 |
+
hit_objects = beatmap.hit_objects(stacking=False)
|
| 80 |
+
last_pos = np.array((256, 192))
|
| 81 |
+
events = []
|
| 82 |
+
event_times = []
|
| 83 |
+
|
| 84 |
+
for hit_object in hit_objects:
|
| 85 |
+
if isinstance(hit_object, Circle):
|
| 86 |
+
last_pos = self._parse_circle(hit_object, events, event_times, last_pos, beatmap, flip_x, flip_y)
|
| 87 |
+
elif isinstance(hit_object, Slider):
|
| 88 |
+
last_pos = self._parse_slider(hit_object, events, event_times, last_pos, beatmap, flip_x, flip_y)
|
| 89 |
+
elif isinstance(hit_object, Spinner):
|
| 90 |
+
last_pos = self._parse_spinner(hit_object, events, event_times, beatmap)
|
| 91 |
+
|
| 92 |
+
if self.add_timing:
|
| 93 |
+
timing_events, timing_times = self.parse_timing(beatmap)
|
| 94 |
+
events, event_times = merge_events(timing_events, timing_times, events, event_times)
|
| 95 |
+
|
| 96 |
+
if speed != 1.0:
|
| 97 |
+
events, event_times = speed_events(events, event_times, speed)
|
| 98 |
+
|
| 99 |
+
return events, event_times
|
| 100 |
+
|
| 101 |
+
def parse_timing(self, beatmap: Beatmap, speed: float = 1.0) -> tuple[list[Event], list[int]]:
|
| 102 |
+
"""Extract all timing information from a beatmap."""
|
| 103 |
+
events = []
|
| 104 |
+
event_times = []
|
| 105 |
+
hit_objects = beatmap.hit_objects(stacking=False)
|
| 106 |
+
if len(hit_objects) == 0:
|
| 107 |
+
last_time = timedelta(milliseconds=0)
|
| 108 |
+
else:
|
| 109 |
+
last_ho = beatmap.hit_objects(stacking=False)[-1]
|
| 110 |
+
last_time = last_ho.end_time if hasattr(last_ho, "end_time") else last_ho.time
|
| 111 |
+
|
| 112 |
+
# Get all timing points with BPM changes
|
| 113 |
+
timing_points = [tp for tp in beatmap.timing_points if tp.bpm]
|
| 114 |
+
|
| 115 |
+
for i, tp in enumerate(timing_points):
|
| 116 |
+
# Generate beat and measure events until the next timing point
|
| 117 |
+
next_tp = timing_points[i + 1] if i + 1 < len(timing_points) else None
|
| 118 |
+
next_time = next_tp.offset - timedelta(milliseconds=10) if next_tp else last_time
|
| 119 |
+
time = tp.offset
|
| 120 |
+
measure_counter = 0
|
| 121 |
+
beat_delta = timedelta(milliseconds=tp.ms_per_beat)
|
| 122 |
+
while time <= next_time:
|
| 123 |
+
if self.add_timing_points and measure_counter == 0:
|
| 124 |
+
event_type = EventType.TIMING_POINT
|
| 125 |
+
elif measure_counter % tp.meter == 0:
|
| 126 |
+
event_type = EventType.MEASURE
|
| 127 |
+
else:
|
| 128 |
+
event_type = EventType.BEAT
|
| 129 |
+
|
| 130 |
+
self._add_group(
|
| 131 |
+
event_type,
|
| 132 |
+
time,
|
| 133 |
+
events,
|
| 134 |
+
event_times,
|
| 135 |
+
beatmap,
|
| 136 |
+
time_event=True,
|
| 137 |
+
add_snap=False,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
measure_counter += 1
|
| 141 |
+
time += beat_delta
|
| 142 |
+
|
| 143 |
+
if speed != 1.0:
|
| 144 |
+
events, event_times = speed_events(events, event_times, speed)
|
| 145 |
+
|
| 146 |
+
return events, event_times
|
| 147 |
+
|
| 148 |
+
@staticmethod
|
| 149 |
+
def uninherited_point_at(time: timedelta, beatmap: Beatmap):
|
| 150 |
+
tp = beatmap.timing_point_at(time)
|
| 151 |
+
return tp if tp.parent is None else tp.parent
|
| 152 |
+
|
| 153 |
+
@staticmethod
|
| 154 |
+
def hitsound_point_at(time: timedelta, beatmap: Beatmap):
|
| 155 |
+
hs_query = time + timedelta(milliseconds=5)
|
| 156 |
+
return beatmap.timing_point_at(hs_query)
|
| 157 |
+
|
| 158 |
+
def _add_time_event(self, time: timedelta, beatmap: Beatmap, events: list[Event], event_times: list[int],
|
| 159 |
+
add_snap: bool = True) -> None:
|
| 160 |
+
"""Add a snapping event to the event list.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
time: Time of the snapping event.
|
| 164 |
+
beatmap: Beatmap object.
|
| 165 |
+
events: List of events to add to.
|
| 166 |
+
add_snap: Whether to add a snapping event.
|
| 167 |
+
"""
|
| 168 |
+
time_ms = int(time.total_seconds() * 1000 + 1e-5)
|
| 169 |
+
events.append(Event(EventType.TIME_SHIFT, time_ms))
|
| 170 |
+
event_times.append(time_ms)
|
| 171 |
+
|
| 172 |
+
if not add_snap or not self.add_snapping:
|
| 173 |
+
return
|
| 174 |
+
|
| 175 |
+
if len(beatmap.timing_points) > 0:
|
| 176 |
+
tp = self.uninherited_point_at(time, beatmap)
|
| 177 |
+
beats = (time - tp.offset).total_seconds() * 1000 / tp.ms_per_beat
|
| 178 |
+
snapping = 0
|
| 179 |
+
for i in range(1, 17):
|
| 180 |
+
# If the difference between the time and the snapped time is less than 2 ms, that is the correct snapping
|
| 181 |
+
if abs(beats - round(beats * i) / i) * tp.ms_per_beat < 2:
|
| 182 |
+
snapping = i
|
| 183 |
+
break
|
| 184 |
+
else:
|
| 185 |
+
snapping = 0
|
| 186 |
+
|
| 187 |
+
events.append(Event(EventType.SNAPPING, snapping))
|
| 188 |
+
event_times.append(time_ms)
|
| 189 |
+
|
| 190 |
+
def _add_hitsound_event(self, time: timedelta, group_time: int, hitsound: int, addition: str, beatmap: Beatmap,
|
| 191 |
+
events: list[Event], event_times: list[int]) -> None:
|
| 192 |
+
if not self.add_hitsounds:
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
+
if len(beatmap.timing_points) > 0:
|
| 196 |
+
tp = self.hitsound_point_at(time, beatmap)
|
| 197 |
+
tp_sample_set = tp.sample_type if tp.sample_type != 0 else 2 # Inherit to soft sample set
|
| 198 |
+
tp_volume = tp.volume
|
| 199 |
+
else:
|
| 200 |
+
tp_sample_set = 2
|
| 201 |
+
tp_volume = 100
|
| 202 |
+
|
| 203 |
+
addition_split = addition.split(":")
|
| 204 |
+
sample_set = int(addition_split[0]) if addition_split[0] != "0" else tp_sample_set
|
| 205 |
+
addition_set = int(addition_split[1]) if addition_split[1] != "0" else sample_set
|
| 206 |
+
|
| 207 |
+
sample_set = sample_set if 0 < sample_set < 4 else 1 # Overflow default to normal sample set
|
| 208 |
+
addition_set = addition_set if 0 < addition_set < 4 else 1 # Overflow default to normal sample set
|
| 209 |
+
hitsound = hitsound & 14 # Only take the bits for normal, whistle, and finish
|
| 210 |
+
|
| 211 |
+
hitsound_idx = hitsound // 2 + 8 * (sample_set - 1) + 24 * (addition_set - 1)
|
| 212 |
+
|
| 213 |
+
events.append(Event(EventType.HITSOUND, hitsound_idx))
|
| 214 |
+
events.append(Event(EventType.VOLUME, tp_volume))
|
| 215 |
+
event_times.append(group_time)
|
| 216 |
+
event_times.append(group_time)
|
| 217 |
+
|
| 218 |
+
def _clip_dist(self, dist: int) -> int:
|
| 219 |
+
"""Clip distance to valid range."""
|
| 220 |
+
return int(np.clip(dist, self.dist_min, self.dist_max))
|
| 221 |
+
|
| 222 |
+
def _scale_clip_pos(self, pos: npt.NDArray) -> Tuple[int, int]:
|
| 223 |
+
"""Clip position to valid range."""
|
| 224 |
+
p = pos / self.position_precision
|
| 225 |
+
return int(np.clip(p[0], self.x_min, self.x_max)), int(np.clip(p[1], self.y_min, self.y_max))
|
| 226 |
+
|
| 227 |
+
def _add_position_event(self, pos: npt.NDArray, last_pos: npt.NDArray, time: timedelta, events: list[Event],
|
| 228 |
+
event_times: list[int], flip_x: bool, flip_y: bool) -> npt.NDArray:
|
| 229 |
+
time_ms = int(time.total_seconds() * 1000 + 1e-5)
|
| 230 |
+
if self.add_distances:
|
| 231 |
+
dist = self._clip_dist(np.linalg.norm(pos - last_pos))
|
| 232 |
+
events.append(Event(EventType.DISTANCE, dist))
|
| 233 |
+
event_times.append(time_ms)
|
| 234 |
+
|
| 235 |
+
if self.add_positions:
|
| 236 |
+
pos_modified = pos.copy()
|
| 237 |
+
if flip_x:
|
| 238 |
+
pos_modified[0] = 512 - pos_modified[0]
|
| 239 |
+
if flip_y:
|
| 240 |
+
pos_modified[1] = 384 - pos_modified[1]
|
| 241 |
+
|
| 242 |
+
p = self._scale_clip_pos(pos_modified)
|
| 243 |
+
if self.position_split_axes:
|
| 244 |
+
events.append(Event(EventType.POS_X, p[0]))
|
| 245 |
+
events.append(Event(EventType.POS_Y, p[1]))
|
| 246 |
+
event_times.append(time_ms)
|
| 247 |
+
event_times.append(time_ms)
|
| 248 |
+
else:
|
| 249 |
+
events.append(Event(EventType.POS, (p[0] - self.x_min) + (p[1] - self.y_min) * self.x_count))
|
| 250 |
+
event_times.append(time_ms)
|
| 251 |
+
|
| 252 |
+
return pos
|
| 253 |
+
|
| 254 |
+
def _add_group(
|
| 255 |
+
self,
|
| 256 |
+
event_type: EventType,
|
| 257 |
+
time: timedelta,
|
| 258 |
+
events: list[Event],
|
| 259 |
+
event_times: list[int],
|
| 260 |
+
beatmap: Beatmap,
|
| 261 |
+
*,
|
| 262 |
+
time_event: bool = False,
|
| 263 |
+
add_snap=True,
|
| 264 |
+
pos: npt.NDArray = None,
|
| 265 |
+
last_pos: npt.NDArray = None,
|
| 266 |
+
new_combo: bool = False,
|
| 267 |
+
hitsound_ref_times: list[timedelta] = None,
|
| 268 |
+
hitsounds: list[int] = None,
|
| 269 |
+
additions: list[str] = None,
|
| 270 |
+
flip_x: bool = False,
|
| 271 |
+
flip_y: bool = False,
|
| 272 |
+
) -> npt.NDArray:
|
| 273 |
+
"""Add a group of events to the event list."""
|
| 274 |
+
time_ms = int(time.total_seconds() * 1000 + 1e-5) if time is not None else None
|
| 275 |
+
|
| 276 |
+
if self.types_first:
|
| 277 |
+
events.append(Event(event_type))
|
| 278 |
+
event_times.append(time_ms)
|
| 279 |
+
if time_event:
|
| 280 |
+
self._add_time_event(time, beatmap, events, event_times, add_snap)
|
| 281 |
+
if pos is not None:
|
| 282 |
+
last_pos = self._add_position_event(pos, last_pos, time, events, event_times, flip_x, flip_y)
|
| 283 |
+
if new_combo:
|
| 284 |
+
events.append(Event(EventType.NEW_COMBO))
|
| 285 |
+
event_times.append(time_ms)
|
| 286 |
+
if hitsound_ref_times is not None:
|
| 287 |
+
for i, ref_time in enumerate(hitsound_ref_times):
|
| 288 |
+
self._add_hitsound_event(ref_time, time_ms, hitsounds[i], additions[i], beatmap, events, event_times)
|
| 289 |
+
if not self.types_first:
|
| 290 |
+
events.append(Event(event_type))
|
| 291 |
+
event_times.append(time_ms)
|
| 292 |
+
|
| 293 |
+
return last_pos
|
| 294 |
+
|
| 295 |
+
def _parse_circle(self, circle: Circle, events: list[Event], event_times: list[int], last_pos: npt.NDArray,
|
| 296 |
+
beatmap: Beatmap, flip_x: bool, flip_y: bool) -> npt.NDArray:
|
| 297 |
+
"""Parse a circle hit object.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
circle: Circle object.
|
| 301 |
+
events: List of events to add to.
|
| 302 |
+
last_pos: Last position of the hit objects.
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
pos: Position of the circle.
|
| 306 |
+
"""
|
| 307 |
+
return self._add_group(
|
| 308 |
+
EventType.CIRCLE,
|
| 309 |
+
circle.time,
|
| 310 |
+
events,
|
| 311 |
+
event_times,
|
| 312 |
+
beatmap,
|
| 313 |
+
time_event=True,
|
| 314 |
+
pos=np.array(circle.position),
|
| 315 |
+
last_pos=last_pos,
|
| 316 |
+
new_combo=circle.new_combo,
|
| 317 |
+
hitsound_ref_times=[circle.time],
|
| 318 |
+
hitsounds=[circle.hitsound],
|
| 319 |
+
additions=[circle.addition],
|
| 320 |
+
flip_x=flip_x,
|
| 321 |
+
flip_y=flip_y,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
def _parse_slider(self, slider: Slider, events: list[Event], event_times: list[int], last_pos: npt.NDArray,
|
| 325 |
+
beatmap: Beatmap, flip_x: bool, flip_y: bool) -> npt.NDArray:
|
| 326 |
+
"""Parse a slider hit object.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
slider: Slider object.
|
| 330 |
+
events: List of events to add to.
|
| 331 |
+
last_pos: Last position of the hit objects.
|
| 332 |
+
|
| 333 |
+
Returns:
|
| 334 |
+
pos: Last position of the slider.
|
| 335 |
+
"""
|
| 336 |
+
# Ignore sliders which are too big
|
| 337 |
+
if len(slider.curve.points) >= 100:
|
| 338 |
+
return last_pos
|
| 339 |
+
|
| 340 |
+
last_pos = self._add_group(
|
| 341 |
+
EventType.SLIDER_HEAD,
|
| 342 |
+
slider.time,
|
| 343 |
+
events,
|
| 344 |
+
event_times,
|
| 345 |
+
beatmap,
|
| 346 |
+
time_event=True,
|
| 347 |
+
pos=np.array(slider.position),
|
| 348 |
+
last_pos=last_pos,
|
| 349 |
+
new_combo=slider.new_combo,
|
| 350 |
+
hitsound_ref_times=[slider.time],
|
| 351 |
+
hitsounds=[slider.edge_sounds[0] if len(slider.edge_sounds) > 0 else 0],
|
| 352 |
+
additions=[slider.edge_additions[0] if len(slider.edge_additions) > 0 else '0:0'],
|
| 353 |
+
flip_x=flip_x,
|
| 354 |
+
flip_y=flip_y,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
duration: timedelta = (slider.end_time - slider.time) / slider.repeat
|
| 358 |
+
control_point_count = len(slider.curve.points)
|
| 359 |
+
|
| 360 |
+
def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
|
| 361 |
+
for i in range(1, control_point_count - 1):
|
| 362 |
+
last_pos = add_anchor(event_type, i, last_pos)
|
| 363 |
+
|
| 364 |
+
return last_pos
|
| 365 |
+
|
| 366 |
+
def add_anchor(event_type: EventType, i: int, last_pos: npt.NDArray) -> npt.NDArray:
|
| 367 |
+
return self._add_group(
|
| 368 |
+
event_type,
|
| 369 |
+
slider.time + i / (control_point_count - 1) * duration,
|
| 370 |
+
events,
|
| 371 |
+
event_times,
|
| 372 |
+
beatmap,
|
| 373 |
+
pos=np.array(slider.curve.points[i]),
|
| 374 |
+
last_pos=last_pos,
|
| 375 |
+
flip_x=flip_x,
|
| 376 |
+
flip_y=flip_y,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
if isinstance(slider.curve, Linear):
|
| 380 |
+
last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
|
| 381 |
+
elif isinstance(slider.curve, Catmull):
|
| 382 |
+
last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
|
| 383 |
+
elif isinstance(slider.curve, Perfect):
|
| 384 |
+
last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
|
| 385 |
+
elif isinstance(slider.curve, MultiBezier):
|
| 386 |
+
for i in range(1, control_point_count - 1):
|
| 387 |
+
if slider.curve.points[i] == slider.curve.points[i + 1]:
|
| 388 |
+
last_pos = add_anchor(EventType.RED_ANCHOR, i, last_pos)
|
| 389 |
+
elif slider.curve.points[i] != slider.curve.points[i - 1]:
|
| 390 |
+
last_pos = add_anchor(EventType.BEZIER_ANCHOR, i, last_pos)
|
| 391 |
+
|
| 392 |
+
# Add body hitsounds and remaining edge hitsounds
|
| 393 |
+
last_pos = self._add_group(
|
| 394 |
+
EventType.LAST_ANCHOR,
|
| 395 |
+
slider.time + duration,
|
| 396 |
+
events,
|
| 397 |
+
event_times,
|
| 398 |
+
beatmap,
|
| 399 |
+
time_event=True,
|
| 400 |
+
pos=np.array(slider.curve.points[-1]),
|
| 401 |
+
last_pos=last_pos,
|
| 402 |
+
hitsound_ref_times=[slider.time + timedelta(milliseconds=1)] + [slider.time + i * duration for i in
|
| 403 |
+
range(1, slider.repeat)],
|
| 404 |
+
hitsounds=[slider.hitsound] + [slider.edge_sounds[i] if len(slider.edge_sounds) > i else 0 for i in
|
| 405 |
+
range(1, slider.repeat)],
|
| 406 |
+
additions=[slider.addition] + [slider.edge_additions[i] if len(slider.edge_additions) > i else '0:0' for i
|
| 407 |
+
in range(1, slider.repeat)],
|
| 408 |
+
flip_x=flip_x,
|
| 409 |
+
flip_y=flip_y,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
return self._add_group(
|
| 413 |
+
EventType.SLIDER_END,
|
| 414 |
+
slider.end_time,
|
| 415 |
+
events,
|
| 416 |
+
event_times,
|
| 417 |
+
beatmap,
|
| 418 |
+
time_event=True,
|
| 419 |
+
pos=np.array(slider.curve(1)),
|
| 420 |
+
last_pos=last_pos,
|
| 421 |
+
hitsound_ref_times=[slider.end_time],
|
| 422 |
+
hitsounds=[slider.edge_sounds[-1] if len(slider.edge_sounds) > 0 else 0],
|
| 423 |
+
additions=[slider.edge_additions[-1] if len(slider.edge_additions) > 0 else '0:0'],
|
| 424 |
+
flip_x=flip_x,
|
| 425 |
+
flip_y=flip_y,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
def _parse_spinner(self, spinner: Spinner, events: list[Event], event_times: list[int],
|
| 429 |
+
beatmap: Beatmap) -> npt.NDArray:
|
| 430 |
+
"""Parse a spinner hit object.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
spinner: Spinner object.
|
| 434 |
+
events: List of events to add to.
|
| 435 |
+
|
| 436 |
+
Returns:
|
| 437 |
+
pos: Last position of the spinner.
|
| 438 |
+
"""
|
| 439 |
+
self._add_group(
|
| 440 |
+
EventType.SPINNER,
|
| 441 |
+
spinner.time,
|
| 442 |
+
events,
|
| 443 |
+
event_times,
|
| 444 |
+
beatmap,
|
| 445 |
+
time_event=True,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
self._add_group(
|
| 449 |
+
EventType.SPINNER_END,
|
| 450 |
+
spinner.end_time,
|
| 451 |
+
events,
|
| 452 |
+
event_times,
|
| 453 |
+
beatmap,
|
| 454 |
+
time_event=True,
|
| 455 |
+
hitsound_ref_times=[spinner.end_time],
|
| 456 |
+
hitsounds=[spinner.hitsound],
|
| 457 |
+
additions=[spinner.addition],
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
return np.array((256, 192))
|
classifier/libs/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model import OsuClassifier
|
classifier/libs/model/model.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
from transformers import T5Config, WhisperConfig, T5Model, WhisperModel
|
| 10 |
+
from transformers.modeling_outputs import Seq2SeqModelOutput
|
| 11 |
+
|
| 12 |
+
from .spectrogram import MelSpectrogram
|
| 13 |
+
from ..tokenizer import Tokenizer
|
| 14 |
+
|
| 15 |
+
LABEL_IGNORE_ID = -100
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class OsuClassifierOutput:
|
| 20 |
+
loss: Optional[torch.FloatTensor] = None
|
| 21 |
+
logits: Optional[torch.FloatTensor] = None
|
| 22 |
+
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 23 |
+
decoder_last_hidden_state: Optional[torch.FloatTensor] = None
|
| 24 |
+
feature_vector: Optional[torch.FloatTensor] = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_backbone_model(args, tokenizer: Tokenizer):
|
| 28 |
+
if args.model.name.startswith("google/t5"):
|
| 29 |
+
config = T5Config.from_pretrained(args.model.name)
|
| 30 |
+
elif args.model.name.startswith("openai/whisper"):
|
| 31 |
+
config = WhisperConfig.from_pretrained(args.model.name)
|
| 32 |
+
else:
|
| 33 |
+
raise NotImplementedError
|
| 34 |
+
|
| 35 |
+
config.vocab_size = tokenizer.vocab_size
|
| 36 |
+
|
| 37 |
+
if hasattr(args.model, "overwrite"):
|
| 38 |
+
for k, v in args.model.overwrite.items():
|
| 39 |
+
assert hasattr(config, k), f"config does not have attribute {k}"
|
| 40 |
+
setattr(config, k, v)
|
| 41 |
+
|
| 42 |
+
if hasattr(args.model, "add_config"):
|
| 43 |
+
for k, v in args.model.add_config.items():
|
| 44 |
+
assert not hasattr(config, k), f"config already has attribute {k}"
|
| 45 |
+
setattr(config, k, v)
|
| 46 |
+
|
| 47 |
+
if args.model.name.startswith("google/t5"):
|
| 48 |
+
model = T5Model(config)
|
| 49 |
+
elif args.model.name.startswith("openai/whisper"):
|
| 50 |
+
config.use_cache = False
|
| 51 |
+
config.num_mel_bins = config.d_model
|
| 52 |
+
config.pad_token_id = tokenizer.pad_id
|
| 53 |
+
config.max_source_positions = args.data.src_seq_len // 2
|
| 54 |
+
config.max_target_positions = args.data.tgt_seq_len
|
| 55 |
+
model = WhisperModel(config)
|
| 56 |
+
else:
|
| 57 |
+
raise NotImplementedError
|
| 58 |
+
|
| 59 |
+
return model, config.d_model
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class OsuClassifier(nn.Module):
|
| 63 |
+
__slots__ = [
|
| 64 |
+
"spectrogram",
|
| 65 |
+
"decoder_embedder",
|
| 66 |
+
"encoder_embedder",
|
| 67 |
+
"transformer",
|
| 68 |
+
"style_embedder",
|
| 69 |
+
"num_classes",
|
| 70 |
+
"input_features",
|
| 71 |
+
"projector",
|
| 72 |
+
"classifier",
|
| 73 |
+
"vocab_size",
|
| 74 |
+
"loss_fn",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
def __init__(self, args: DictConfig, tokenizer: Tokenizer):
|
| 78 |
+
super().__init__()
|
| 79 |
+
|
| 80 |
+
self.transformer, d_model = get_backbone_model(args, tokenizer)
|
| 81 |
+
self.num_classes = tokenizer.num_classes
|
| 82 |
+
self.input_features = args.model.input_features
|
| 83 |
+
|
| 84 |
+
self.decoder_embedder = nn.Embedding(tokenizer.vocab_size, d_model)
|
| 85 |
+
self.decoder_embedder.weight.data.normal_(mean=0.0, std=1.0)
|
| 86 |
+
|
| 87 |
+
self.spectrogram = MelSpectrogram(
|
| 88 |
+
args.model.spectrogram.sample_rate, args.model.spectrogram.n_fft,
|
| 89 |
+
args.model.spectrogram.n_mels, args.model.spectrogram.hop_length
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
self.encoder_embedder = nn.Linear(args.model.spectrogram.n_mels, d_model)
|
| 93 |
+
|
| 94 |
+
self.projector = nn.Linear(d_model, args.model.classifier_proj_size)
|
| 95 |
+
self.classifier = nn.Linear(args.model.classifier_proj_size, tokenizer.num_classes)
|
| 96 |
+
|
| 97 |
+
self.vocab_size = tokenizer.vocab_size
|
| 98 |
+
self.loss_fn = nn.CrossEntropyLoss()
|
| 99 |
+
|
| 100 |
+
def forward(
|
| 101 |
+
self,
|
| 102 |
+
frames: Optional[torch.FloatTensor] = None,
|
| 103 |
+
decoder_input_ids: Optional[torch.Tensor] = None,
|
| 104 |
+
labels: Optional[torch.LongTensor] = None,
|
| 105 |
+
**kwargs
|
| 106 |
+
) -> OsuClassifierOutput:
|
| 107 |
+
"""
|
| 108 |
+
frames: B x L_encoder x mel_bins, float32
|
| 109 |
+
decoder_input_ids: B x L_decoder, int64
|
| 110 |
+
beatmap_id: B, int64
|
| 111 |
+
encoder_outputs: B x L_encoder x D, float32
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
frames = self.spectrogram(frames) # (N, L, M)
|
| 115 |
+
inputs_embeds = self.encoder_embedder(frames)
|
| 116 |
+
decoder_inputs_embeds = self.decoder_embedder(decoder_input_ids)
|
| 117 |
+
|
| 118 |
+
if self.input_features:
|
| 119 |
+
input_features = torch.swapaxes(inputs_embeds, 1, 2) if inputs_embeds is not None else None
|
| 120 |
+
# noinspection PyTypeChecker
|
| 121 |
+
base_output: Seq2SeqModelOutput = self.transformer.forward(input_features=input_features,
|
| 122 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 123 |
+
**kwargs)
|
| 124 |
+
else:
|
| 125 |
+
base_output = self.transformer.forward(inputs_embeds=inputs_embeds,
|
| 126 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 127 |
+
**kwargs)
|
| 128 |
+
|
| 129 |
+
# Get logits
|
| 130 |
+
hidden_states = self.projector(base_output.last_hidden_state)
|
| 131 |
+
pooled_output = hidden_states.mean(dim=1)
|
| 132 |
+
|
| 133 |
+
logits = self.classifier(pooled_output)
|
| 134 |
+
loss = None
|
| 135 |
+
|
| 136 |
+
if labels is not None:
|
| 137 |
+
loss = self.loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
|
| 138 |
+
|
| 139 |
+
return OsuClassifierOutput(
|
| 140 |
+
loss=loss,
|
| 141 |
+
logits=logits,
|
| 142 |
+
encoder_last_hidden_state=base_output.encoder_last_hidden_state,
|
| 143 |
+
decoder_last_hidden_state=base_output.last_hidden_state,
|
| 144 |
+
feature_vector=pooled_output
|
| 145 |
+
)
|
classifier/libs/model/spectrogram.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from nnAudio import features
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MelSpectrogram(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
sample_rate: int = 16000,
|
| 12 |
+
n_ftt: int = 2048,
|
| 13 |
+
n_mels: int = 512,
|
| 14 |
+
hop_length: int = 128,
|
| 15 |
+
):
|
| 16 |
+
"""
|
| 17 |
+
Melspectrogram transformation layer, supports on-the-fly processing on GPU.
|
| 18 |
+
|
| 19 |
+
Attributes:
|
| 20 |
+
sample_rate: The sampling rate for the input audio.
|
| 21 |
+
n_ftt: The window size for the STFT.
|
| 22 |
+
n_mels: The number of Mel filter banks.
|
| 23 |
+
hop_length: The hop (or stride) size.
|
| 24 |
+
"""
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.transform = features.MelSpectrogram(
|
| 27 |
+
sr=sample_rate,
|
| 28 |
+
n_fft=n_ftt,
|
| 29 |
+
n_mels=n_mels,
|
| 30 |
+
hop_length=hop_length,
|
| 31 |
+
center=True,
|
| 32 |
+
fmin=0,
|
| 33 |
+
fmax=sample_rate // 2,
|
| 34 |
+
pad_mode="constant",
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
def forward(self, samples: torch.tensor) -> torch.tensor:
|
| 38 |
+
"""
|
| 39 |
+
Convert a batch of audio frames into a batch of Mel spectrogram frames.
|
| 40 |
+
|
| 41 |
+
For each item in the batch:
|
| 42 |
+
1. pad left and right ends of audio by n_fft // 2.
|
| 43 |
+
2. run STFT with window size of |n_ftt| and stride of |hop_length|.
|
| 44 |
+
3. convert result into mel-scale.
|
| 45 |
+
4. therefore, n_frames = n_samples // hop_length + 1.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
samples: Audio time-series (batch size, n_samples).
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
A batch of Mel spectrograms of size (batch size, n_frames, n_mels).
|
| 52 |
+
"""
|
| 53 |
+
spectrogram = self.transform(samples)
|
| 54 |
+
spectrogram = spectrogram.permute(0, 2, 1)
|
| 55 |
+
return spectrogram
|
classifier/libs/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .event import *
|
| 2 |
+
from .tokenizer import Tokenizer
|
classifier/libs/tokenizer/event.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import dataclasses
|
| 4 |
+
from enum import Enum
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class EventType(Enum):
|
| 8 |
+
TIME_SHIFT = "t"
|
| 9 |
+
SNAPPING = "snap"
|
| 10 |
+
DISTANCE = "dist"
|
| 11 |
+
NEW_COMBO = "new_combo"
|
| 12 |
+
HITSOUND = "hitsound"
|
| 13 |
+
VOLUME = "volume"
|
| 14 |
+
CIRCLE = "circle"
|
| 15 |
+
SPINNER = "spinner"
|
| 16 |
+
SPINNER_END = "spinner_end"
|
| 17 |
+
SLIDER_HEAD = "slider_head"
|
| 18 |
+
BEZIER_ANCHOR = "bezier_anchor"
|
| 19 |
+
PERFECT_ANCHOR = "perfect_anchor"
|
| 20 |
+
CATMULL_ANCHOR = "catmull_anchor"
|
| 21 |
+
RED_ANCHOR = "red_anchor"
|
| 22 |
+
LAST_ANCHOR = "last_anchor"
|
| 23 |
+
SLIDER_END = "slider_end"
|
| 24 |
+
BEAT = "beat"
|
| 25 |
+
MEASURE = "measure"
|
| 26 |
+
TIMING_POINT = "timing_point"
|
| 27 |
+
STYLE = "style"
|
| 28 |
+
DIFFICULTY = "difficulty"
|
| 29 |
+
MAPPER = "mapper"
|
| 30 |
+
DESCRIPTOR = "descriptor"
|
| 31 |
+
POS_X = "pos_x"
|
| 32 |
+
POS_Y = "pos_y"
|
| 33 |
+
POS = "pos"
|
| 34 |
+
CS = "cs"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclasses.dataclass
|
| 38 |
+
class EventRange:
|
| 39 |
+
type: EventType
|
| 40 |
+
min_value: int
|
| 41 |
+
max_value: int
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclasses.dataclass
|
| 45 |
+
class Event:
|
| 46 |
+
type: EventType
|
| 47 |
+
value: int = 0
|
| 48 |
+
|
| 49 |
+
def __repr__(self) -> str:
|
| 50 |
+
return f"{self.type.value}{self.value}"
|
| 51 |
+
|
| 52 |
+
def __str__(self) -> str:
|
| 53 |
+
return f"{self.type.value}{self.value}"
|
classifier/libs/tokenizer/tokenizer.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
from .event import Event, EventType, EventRange
|
| 7 |
+
|
| 8 |
+
MILISECONDS_PER_SECOND = 1000
|
| 9 |
+
MILISECONDS_PER_STEP = 10
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Tokenizer:
|
| 13 |
+
__slots__ = [
|
| 14 |
+
"offset",
|
| 15 |
+
"event_ranges",
|
| 16 |
+
"input_event_ranges",
|
| 17 |
+
"num_classes",
|
| 18 |
+
"num_diff_classes",
|
| 19 |
+
"max_difficulty",
|
| 20 |
+
"event_range",
|
| 21 |
+
"event_start",
|
| 22 |
+
"event_end",
|
| 23 |
+
"vocab_size",
|
| 24 |
+
"beatmap_idx",
|
| 25 |
+
"mapper_idx",
|
| 26 |
+
"beatmap_mapper",
|
| 27 |
+
"num_mapper_classes",
|
| 28 |
+
"beatmap_descriptors",
|
| 29 |
+
"descriptor_idx",
|
| 30 |
+
"num_descriptor_classes",
|
| 31 |
+
"num_cs_classes",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
def __init__(self, args: DictConfig = None):
|
| 35 |
+
"""Fixed vocabulary tokenizer."""
|
| 36 |
+
self.offset = 1
|
| 37 |
+
self.event_ranges: list[EventRange] = [
|
| 38 |
+
EventRange(EventType.TIME_SHIFT, 0, 1024),
|
| 39 |
+
EventRange(EventType.SNAPPING, 0, 16),
|
| 40 |
+
EventRange(EventType.DISTANCE, 0, 640),
|
| 41 |
+
]
|
| 42 |
+
self.num_classes = 0
|
| 43 |
+
self.beatmap_mapper: dict[int, int] = {} # beatmap_id -> mapper_id
|
| 44 |
+
self.mapper_idx: dict[int, int] = {} # mapper_id -> mapper_idx
|
| 45 |
+
|
| 46 |
+
if args is not None:
|
| 47 |
+
miliseconds_per_sequence = ((args.data.src_seq_len - 1) * args.model.spectrogram.hop_length *
|
| 48 |
+
MILISECONDS_PER_SECOND / args.model.spectrogram.sample_rate)
|
| 49 |
+
max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP)
|
| 50 |
+
min_time_shift = 0
|
| 51 |
+
|
| 52 |
+
self.event_ranges = [
|
| 53 |
+
EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift),
|
| 54 |
+
EventRange(EventType.SNAPPING, 0, 16),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
self._init_mapper_idx(args)
|
| 58 |
+
|
| 59 |
+
if args.data.add_distances:
|
| 60 |
+
self.event_ranges.append(EventRange(EventType.DISTANCE, 0, 640))
|
| 61 |
+
|
| 62 |
+
if args.data.add_positions:
|
| 63 |
+
p = args.data.position_precision
|
| 64 |
+
x_min, x_max, y_min, y_max = args.data.position_range
|
| 65 |
+
x_min, x_max, y_min, y_max = x_min // p, x_max // p, y_min // p, y_max // p
|
| 66 |
+
|
| 67 |
+
if args.data.position_split_axes:
|
| 68 |
+
self.event_ranges.append(EventRange(EventType.POS_X, x_min, x_max))
|
| 69 |
+
self.event_ranges.append(EventRange(EventType.POS_Y, y_min, y_max))
|
| 70 |
+
else:
|
| 71 |
+
x_count = x_max - x_min + 1
|
| 72 |
+
y_count = y_max - y_min + 1
|
| 73 |
+
self.event_ranges.append(EventRange(EventType.POS, 0, x_count * y_count - 1))
|
| 74 |
+
|
| 75 |
+
self.event_ranges: list[EventRange] = self.event_ranges + [
|
| 76 |
+
EventRange(EventType.NEW_COMBO, 0, 0),
|
| 77 |
+
EventRange(EventType.HITSOUND, 0, 2 ** 3 * 3 * 3),
|
| 78 |
+
EventRange(EventType.VOLUME, 0, 100),
|
| 79 |
+
EventRange(EventType.CIRCLE, 0, 0),
|
| 80 |
+
EventRange(EventType.SPINNER, 0, 0),
|
| 81 |
+
EventRange(EventType.SPINNER_END, 0, 0),
|
| 82 |
+
EventRange(EventType.SLIDER_HEAD, 0, 0),
|
| 83 |
+
EventRange(EventType.BEZIER_ANCHOR, 0, 0),
|
| 84 |
+
EventRange(EventType.PERFECT_ANCHOR, 0, 0),
|
| 85 |
+
EventRange(EventType.CATMULL_ANCHOR, 0, 0),
|
| 86 |
+
EventRange(EventType.RED_ANCHOR, 0, 0),
|
| 87 |
+
EventRange(EventType.LAST_ANCHOR, 0, 0),
|
| 88 |
+
EventRange(EventType.SLIDER_END, 0, 0),
|
| 89 |
+
EventRange(EventType.BEAT, 0, 0),
|
| 90 |
+
EventRange(EventType.MEASURE, 0, 0),
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
if args is not None and args.data.add_timing_points:
|
| 94 |
+
self.event_ranges.append(EventRange(EventType.TIMING_POINT, 0, 0))
|
| 95 |
+
|
| 96 |
+
self.event_range: dict[EventType, EventRange] = {er.type: er for er in self.event_ranges}
|
| 97 |
+
|
| 98 |
+
self.event_start: dict[EventType, int] = {}
|
| 99 |
+
self.event_end: dict[EventType, int] = {}
|
| 100 |
+
offset = self.offset
|
| 101 |
+
for er in self.event_ranges:
|
| 102 |
+
self.event_start[er.type] = offset
|
| 103 |
+
offset += er.max_value - er.min_value + 1
|
| 104 |
+
self.event_end[er.type] = offset
|
| 105 |
+
|
| 106 |
+
self.vocab_size: int = self.offset + sum(
|
| 107 |
+
er.max_value - er.min_value + 1 for er in self.event_ranges
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def pad_id(self) -> int:
|
| 112 |
+
"""[PAD] token for padding."""
|
| 113 |
+
return 0
|
| 114 |
+
|
| 115 |
+
def decode(self, token_id: int) -> Event:
|
| 116 |
+
"""Converts token ids into Event objects."""
|
| 117 |
+
offset = self.offset
|
| 118 |
+
for er in self.event_ranges:
|
| 119 |
+
if offset <= token_id <= offset + er.max_value - er.min_value:
|
| 120 |
+
return Event(type=er.type, value=er.min_value + token_id - offset)
|
| 121 |
+
offset += er.max_value - er.min_value + 1
|
| 122 |
+
for er in self.input_event_ranges:
|
| 123 |
+
if offset <= token_id <= offset + er.max_value - er.min_value:
|
| 124 |
+
return Event(type=er.type, value=er.min_value + token_id - offset)
|
| 125 |
+
offset += er.max_value - er.min_value + 1
|
| 126 |
+
|
| 127 |
+
raise ValueError(f"id {token_id} is not mapped to any event")
|
| 128 |
+
|
| 129 |
+
def encode(self, event: Event) -> int:
|
| 130 |
+
"""Converts Event objects into token ids."""
|
| 131 |
+
if event.type not in self.event_range:
|
| 132 |
+
raise ValueError(f"unknown event type: {event.type}")
|
| 133 |
+
|
| 134 |
+
er = self.event_range[event.type]
|
| 135 |
+
offset = self.event_start[event.type]
|
| 136 |
+
|
| 137 |
+
if not er.min_value <= event.value <= er.max_value:
|
| 138 |
+
raise ValueError(
|
| 139 |
+
f"event value {event.value} is not within range "
|
| 140 |
+
f"[{er.min_value}, {er.max_value}] for event type {event.type}"
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
return offset + event.value - er.min_value
|
| 144 |
+
|
| 145 |
+
def event_type_range(self, event_type: EventType) -> tuple[int, int]:
|
| 146 |
+
"""Get the token id range of each Event type."""
|
| 147 |
+
if event_type not in self.event_range:
|
| 148 |
+
raise ValueError(f"unknown event type: {event_type}")
|
| 149 |
+
|
| 150 |
+
er = self.event_range[event_type]
|
| 151 |
+
offset = self.event_start[event_type]
|
| 152 |
+
return offset, offset + (er.max_value - er.min_value)
|
| 153 |
+
|
| 154 |
+
def _init_mapper_idx(self, args):
|
| 155 |
+
""""Indexes beatmap mappers and mapper idx."""
|
| 156 |
+
if args is None or "mappers_path" not in args.data:
|
| 157 |
+
raise ValueError("mappers_path not found in args")
|
| 158 |
+
|
| 159 |
+
path = Path(args.data.mappers_path)
|
| 160 |
+
|
| 161 |
+
if not path.exists():
|
| 162 |
+
raise ValueError(f"mappers_path {path} not found")
|
| 163 |
+
|
| 164 |
+
# Load JSON data from file
|
| 165 |
+
with open(path, 'r') as file:
|
| 166 |
+
data = json.load(file)
|
| 167 |
+
|
| 168 |
+
# Populate beatmap_mapper
|
| 169 |
+
for item in data:
|
| 170 |
+
self.beatmap_mapper[item['id']] = item['user_id']
|
| 171 |
+
|
| 172 |
+
# Get unique user_ids from beatmap_mapper values
|
| 173 |
+
unique_user_ids = list(set(self.beatmap_mapper.values()))
|
| 174 |
+
|
| 175 |
+
# Create mapper_idx
|
| 176 |
+
self.mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
|
| 177 |
+
self.num_classes = len(unique_user_ids)
|
| 178 |
+
|
| 179 |
+
def state_dict(self):
|
| 180 |
+
return {
|
| 181 |
+
"offset": self.offset,
|
| 182 |
+
"event_ranges": self.event_ranges,
|
| 183 |
+
"num_classes": self.num_classes,
|
| 184 |
+
"event_range": self.event_range,
|
| 185 |
+
"event_start": self.event_start,
|
| 186 |
+
"event_end": self.event_end,
|
| 187 |
+
"vocab_size": self.vocab_size,
|
| 188 |
+
"beatmap_mapper": self.beatmap_mapper,
|
| 189 |
+
"mapper_idx": self.mapper_idx,
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
def load_state_dict(self, state_dict):
|
| 193 |
+
self.offset = state_dict["offset"]
|
| 194 |
+
self.event_ranges = state_dict["event_ranges"]
|
| 195 |
+
self.num_classes = state_dict["num_classes"]
|
| 196 |
+
self.event_range = state_dict["event_range"]
|
| 197 |
+
self.event_start = state_dict["event_start"]
|
| 198 |
+
self.event_end = state_dict["event_end"]
|
| 199 |
+
self.vocab_size = state_dict["vocab_size"]
|
| 200 |
+
self.beatmap_mapper = state_dict["beatmap_mapper"]
|
| 201 |
+
self.mapper_idx = state_dict["mapper_idx"]
|
classifier/libs/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .model_utils import *
|
classifier/libs/utils/model_utils.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import lightning
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torchmetrics
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
from torch.optim import Optimizer, AdamW
|
| 10 |
+
from torch.optim.lr_scheduler import (
|
| 11 |
+
LRScheduler,
|
| 12 |
+
SequentialLR,
|
| 13 |
+
LinearLR,
|
| 14 |
+
CosineAnnealingLR,
|
| 15 |
+
)
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
|
| 18 |
+
from transformers.utils import cached_file
|
| 19 |
+
|
| 20 |
+
import routed_pickle
|
| 21 |
+
|
| 22 |
+
from ..dataset import OrsDataset, OsuParser
|
| 23 |
+
from ..model import OsuClassifier
|
| 24 |
+
from ..model.model import OsuClassifierOutput
|
| 25 |
+
from ..tokenizer import Tokenizer
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LitOsuClassifier(lightning.LightningModule):
|
| 29 |
+
def __init__(self, args: DictConfig, tokenizer):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.save_hyperparameters()
|
| 32 |
+
self.args = args
|
| 33 |
+
self.model: OsuClassifier = OsuClassifier(args, tokenizer)
|
| 34 |
+
|
| 35 |
+
def forward(self, **kwargs) -> OsuClassifierOutput:
|
| 36 |
+
return self.model(**kwargs)
|
| 37 |
+
|
| 38 |
+
def training_step(self, batch, batch_idx):
|
| 39 |
+
output: Seq2SeqSequenceClassifierOutput = self.model(**batch)
|
| 40 |
+
loss = output.loss
|
| 41 |
+
self.log("train_loss", loss)
|
| 42 |
+
return loss
|
| 43 |
+
|
| 44 |
+
def testy_step(self, batch, batch_idx, prefix):
|
| 45 |
+
output: Seq2SeqSequenceClassifierOutput = self.model(**batch)
|
| 46 |
+
loss = output.loss
|
| 47 |
+
preds = output.logits.argmax(dim=1)
|
| 48 |
+
labels = batch["labels"]
|
| 49 |
+
accuracy = torchmetrics.functional.accuracy(preds, labels, "multiclass", num_classes=self.args.data.num_classes)
|
| 50 |
+
accuracy_10 = torchmetrics.functional.accuracy(output.logits, labels, "multiclass", num_classes=self.args.data.num_classes, top_k=10)
|
| 51 |
+
accuracy_100 = torchmetrics.functional.accuracy(output.logits, labels, "multiclass", num_classes=self.args.data.num_classes, top_k=100)
|
| 52 |
+
self.log(f"{prefix}_loss", loss)
|
| 53 |
+
self.log(f"{prefix}_accuracy", accuracy)
|
| 54 |
+
self.log(f"{prefix}_top_10_accuracy", accuracy_10)
|
| 55 |
+
self.log(f"{prefix}_top_100_accuracy", accuracy_100)
|
| 56 |
+
return loss
|
| 57 |
+
|
| 58 |
+
def validation_step(self, batch, batch_idx):
|
| 59 |
+
return self.testy_step(batch, batch_idx, "val")
|
| 60 |
+
|
| 61 |
+
def test_step(self, batch, batch_idx):
|
| 62 |
+
return self.testy_step(batch, batch_idx, "test")
|
| 63 |
+
|
| 64 |
+
def configure_optimizers(self):
|
| 65 |
+
optimizer = get_optimizer(self.parameters(), self.args)
|
| 66 |
+
scheduler = get_scheduler(optimizer, self.args)
|
| 67 |
+
return {"optimizer": optimizer, "lr_scheduler": {
|
| 68 |
+
"scheduler": scheduler,
|
| 69 |
+
"interval": "step",
|
| 70 |
+
"frequency": 1,
|
| 71 |
+
}}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_ckpt(ckpt_path, route_pickle=True):
|
| 75 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 76 |
+
|
| 77 |
+
if not os.path.exists(ckpt_path):
|
| 78 |
+
ckpt_path = cached_file(ckpt_path, "model.ckpt")
|
| 79 |
+
else:
|
| 80 |
+
ckpt_path = Path(ckpt_path)
|
| 81 |
+
|
| 82 |
+
checkpoint = torch.load(
|
| 83 |
+
ckpt_path,
|
| 84 |
+
map_location=lambda storage, loc: storage,
|
| 85 |
+
weights_only=False,
|
| 86 |
+
pickle_module=routed_pickle if route_pickle else None
|
| 87 |
+
)
|
| 88 |
+
tokenizer = checkpoint["hyper_parameters"]["tokenizer"]
|
| 89 |
+
model_args = checkpoint["hyper_parameters"]["args"]
|
| 90 |
+
state_dict = checkpoint["state_dict"]
|
| 91 |
+
non_compiled_state_dict = {}
|
| 92 |
+
for k, v in state_dict.items():
|
| 93 |
+
if k.startswith("model._orig_mod."):
|
| 94 |
+
non_compiled_state_dict["model." + k[16:]] = v
|
| 95 |
+
else:
|
| 96 |
+
non_compiled_state_dict[k] = v
|
| 97 |
+
|
| 98 |
+
model = LitOsuClassifier(model_args, tokenizer)
|
| 99 |
+
model.load_state_dict(non_compiled_state_dict)
|
| 100 |
+
model.eval().to(device)
|
| 101 |
+
return model, model_args, tokenizer
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_tokenizer(args: DictConfig) -> Tokenizer:
|
| 105 |
+
return Tokenizer(args)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def get_optimizer(parameters, args: DictConfig) -> Optimizer:
|
| 109 |
+
if args.optim.name == 'adamw':
|
| 110 |
+
optimizer = AdamW(
|
| 111 |
+
parameters,
|
| 112 |
+
lr=args.optim.base_lr,
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
raise NotImplementedError
|
| 116 |
+
|
| 117 |
+
return optimizer
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def get_scheduler(optimizer: Optimizer, args: DictConfig, num_processes=1) -> LRScheduler:
|
| 121 |
+
scheduler_p1 = LinearLR(
|
| 122 |
+
optimizer,
|
| 123 |
+
start_factor=0.5,
|
| 124 |
+
end_factor=1,
|
| 125 |
+
total_iters=args.optim.warmup_steps * num_processes,
|
| 126 |
+
last_epoch=-1,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
scheduler_p2 = CosineAnnealingLR(
|
| 130 |
+
optimizer,
|
| 131 |
+
T_max=args.optim.total_steps * num_processes - args.optim.warmup_steps * num_processes,
|
| 132 |
+
eta_min=args.optim.final_cosine,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
scheduler = SequentialLR(
|
| 136 |
+
optimizer,
|
| 137 |
+
schedulers=[scheduler_p1, scheduler_p2],
|
| 138 |
+
milestones=[args.optim.warmup_steps * num_processes],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
return scheduler
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_dataloaders(tokenizer: Tokenizer, args: DictConfig) -> tuple[DataLoader, DataLoader]:
|
| 145 |
+
parser = OsuParser(args, tokenizer)
|
| 146 |
+
dataset = {
|
| 147 |
+
"train": OrsDataset(
|
| 148 |
+
args.data,
|
| 149 |
+
parser,
|
| 150 |
+
tokenizer,
|
| 151 |
+
),
|
| 152 |
+
"test": OrsDataset(
|
| 153 |
+
args.data,
|
| 154 |
+
parser,
|
| 155 |
+
tokenizer,
|
| 156 |
+
test=True,
|
| 157 |
+
),
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
dataloaders = {}
|
| 161 |
+
for split in ["train", "test"]:
|
| 162 |
+
batch_size = args.optim.batch_size // args.optim.grad_acc
|
| 163 |
+
|
| 164 |
+
dataloaders[split] = DataLoader(
|
| 165 |
+
dataset[split],
|
| 166 |
+
batch_size=batch_size,
|
| 167 |
+
num_workers=args.dataloader.num_workers,
|
| 168 |
+
pin_memory=True,
|
| 169 |
+
drop_last=False,
|
| 170 |
+
persistent_workers=args.dataloader.num_workers > 0,
|
| 171 |
+
worker_init_fn=worker_init_fn,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
return dataloaders["train"], dataloaders["test"]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def worker_init_fn(worker_id: int) -> None:
|
| 178 |
+
"""
|
| 179 |
+
Give each dataloader a unique slice of the full dataset.
|
| 180 |
+
"""
|
| 181 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 182 |
+
dataset = worker_info.dataset # the dataset copy in this worker process
|
| 183 |
+
overall_start = dataset.start
|
| 184 |
+
overall_end = dataset.end
|
| 185 |
+
# configure the dataset to only process the split workload
|
| 186 |
+
per_worker = int(
|
| 187 |
+
np.ceil((overall_end - overall_start) / float(worker_info.num_workers)),
|
| 188 |
+
)
|
| 189 |
+
dataset.start = overall_start + worker_id * per_worker
|
| 190 |
+
dataset.end = min(dataset.start + per_worker, overall_end)
|
classifier/libs/utils/routed_pickle.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
from typing import Dict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Unpickler(pickle.Unpickler):
|
| 6 |
+
load_module_mapping: Dict[str, str] = {
|
| 7 |
+
'osuT5.tokenizer.event': 'osuT5.osuT5.event',
|
| 8 |
+
'libs.tokenizer.event': 'classifier.libs.tokenizer.event',
|
| 9 |
+
'libs.tokenizer.tokenizer': 'classifier.libs.tokenizer.tokenizer',
|
| 10 |
+
'osuT5.event': 'osuT5.osuT5.event',
|
| 11 |
+
'libs.event': 'classifier.libs.tokenizer.event',
|
| 12 |
+
'libs.tokenizer': 'classifier.libs.tokenizer.tokenizer',
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
def find_class(self, mod_name, name):
|
| 16 |
+
mod_name = self.load_module_mapping.get(mod_name, mod_name)
|
| 17 |
+
return super().find_class(mod_name, name)
|
classifier/test.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hydra
|
| 2 |
+
import lightning
|
| 3 |
+
import torch
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
from classifier.libs.utils import load_ckpt
|
| 7 |
+
from libs import (
|
| 8 |
+
get_dataloaders,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
torch.set_float32_matmul_precision('high')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@hydra.main(config_path="configs", config_name="train_v1", version_base="1.1")
|
| 15 |
+
def main(args: DictConfig):
|
| 16 |
+
model, model_args, tokenizer = load_ckpt(args.checkpoint_path, route_pickle=False)
|
| 17 |
+
|
| 18 |
+
_, val_dataloader = get_dataloaders(tokenizer, args)
|
| 19 |
+
|
| 20 |
+
if args.compile:
|
| 21 |
+
model.model = torch.compile(model.model)
|
| 22 |
+
|
| 23 |
+
trainer = lightning.Trainer(
|
| 24 |
+
accelerator=args.device,
|
| 25 |
+
precision=args.precision,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
trainer.test(model, val_dataloader)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main()
|
classifier/train.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
import lightning
|
| 5 |
+
import torch
|
| 6 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 7 |
+
from lightning.pytorch.loggers import WandbLogger
|
| 8 |
+
from omegaconf import DictConfig
|
| 9 |
+
|
| 10 |
+
from libs import (
|
| 11 |
+
get_tokenizer,
|
| 12 |
+
get_dataloaders,
|
| 13 |
+
)
|
| 14 |
+
from libs.model.model import OsuClassifier
|
| 15 |
+
from libs.utils.model_utils import LitOsuClassifier
|
| 16 |
+
torch.set_float32_matmul_precision('high')
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_old_model(path: str, model: OsuClassifier):
|
| 20 |
+
ckpt_path = Path(path)
|
| 21 |
+
model_state = torch.load(ckpt_path / "pytorch_model.bin", weights_only=True)
|
| 22 |
+
|
| 23 |
+
ignore_list = [
|
| 24 |
+
"transformer.model.decoder.embed_tokens.weight",
|
| 25 |
+
"transformer.model.decoder.embed_positions.weight",
|
| 26 |
+
"decoder_embedder.weight",
|
| 27 |
+
"transformer.proj_out.weight",
|
| 28 |
+
"loss_fn.weight",
|
| 29 |
+
]
|
| 30 |
+
fixed_model_state = {}
|
| 31 |
+
|
| 32 |
+
for k, v in model_state.items():
|
| 33 |
+
if k in ignore_list:
|
| 34 |
+
continue
|
| 35 |
+
if k.startswith("transformer.model."):
|
| 36 |
+
fixed_model_state["transformer." + k[18:]] = v
|
| 37 |
+
else:
|
| 38 |
+
fixed_model_state[k] = v
|
| 39 |
+
|
| 40 |
+
model.load_state_dict(fixed_model_state, strict=False)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@hydra.main(config_path="configs", config_name="train_v1", version_base="1.1")
|
| 44 |
+
def main(args: DictConfig):
|
| 45 |
+
wandb_logger = WandbLogger(
|
| 46 |
+
project="osu-classifier",
|
| 47 |
+
entity="mappingtools",
|
| 48 |
+
job_type="training",
|
| 49 |
+
offline=args.logging.mode == "offline",
|
| 50 |
+
log_model="all" if args.logging.mode == "online" else False,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
tokenizer = get_tokenizer(args)
|
| 54 |
+
train_dataloader, val_dataloader = get_dataloaders(tokenizer, args)
|
| 55 |
+
|
| 56 |
+
model = LitOsuClassifier(args, tokenizer)
|
| 57 |
+
|
| 58 |
+
if args.pretrained_path:
|
| 59 |
+
load_old_model(args.pretrained_path, model.model)
|
| 60 |
+
|
| 61 |
+
if args.compile:
|
| 62 |
+
model.model = torch.compile(model.model)
|
| 63 |
+
|
| 64 |
+
checkpoint_callback = ModelCheckpoint(every_n_train_steps=args.checkpoint.every_steps, save_top_k=2, monitor="val_loss")
|
| 65 |
+
lr_monitor = LearningRateMonitor(logging_interval="step")
|
| 66 |
+
trainer = lightning.Trainer(
|
| 67 |
+
accelerator=args.device,
|
| 68 |
+
precision=args.precision,
|
| 69 |
+
logger=wandb_logger,
|
| 70 |
+
max_steps=args.optim.total_steps,
|
| 71 |
+
accumulate_grad_batches=args.optim.grad_acc,
|
| 72 |
+
gradient_clip_val=args.optim.grad_clip,
|
| 73 |
+
val_check_interval=args.eval.every_steps,
|
| 74 |
+
log_every_n_steps=args.logging.every_steps,
|
| 75 |
+
callbacks=[checkpoint_callback, lr_monitor],
|
| 76 |
+
)
|
| 77 |
+
trainer.fit(model, train_dataloader, val_dataloader)
|
| 78 |
+
trainer.save_checkpoint("final.ckpt")
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
cli_inference.sh
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Mapperatorinator CLI - Interactive Inference Script
|
| 4 |
+
# Based on web-ui.py functionality
|
| 5 |
+
|
| 6 |
+
set -e # Exit on error
|
| 7 |
+
|
| 8 |
+
# Colors for better UI
|
| 9 |
+
RED='\033[0;31m'
|
| 10 |
+
GREEN='\033[0;32m'
|
| 11 |
+
YELLOW='\033[1;33m'
|
| 12 |
+
BLUE='\033[0;34m'
|
| 13 |
+
PURPLE='\033[0;35m'
|
| 14 |
+
CYAN='\033[0;36m'
|
| 15 |
+
NC='\033[0m' # No Color
|
| 16 |
+
|
| 17 |
+
# Function to print colored text
|
| 18 |
+
print_color() {
|
| 19 |
+
local color=$1
|
| 20 |
+
local text=$2
|
| 21 |
+
echo -e "${color}${text}${NC}"
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
# Function to print section headers
|
| 25 |
+
print_header() {
|
| 26 |
+
echo
|
| 27 |
+
print_color $CYAN "======================================"
|
| 28 |
+
print_color $CYAN "$1"
|
| 29 |
+
print_color $CYAN "======================================"
|
| 30 |
+
echo
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
# Function to prompt for input with default value
|
| 34 |
+
prompt_input() {
|
| 35 |
+
local prompt=$1
|
| 36 |
+
local default=$2
|
| 37 |
+
local var_name=$3
|
| 38 |
+
|
| 39 |
+
if [ -n "$default" ]; then
|
| 40 |
+
read -e -p "$(print_color $GREEN "$prompt") [default: $default]: " input
|
| 41 |
+
if [ -z "$input" ]; then
|
| 42 |
+
input="$default"
|
| 43 |
+
fi
|
| 44 |
+
else
|
| 45 |
+
read -e -p "$(print_color $GREEN "$prompt"): " input
|
| 46 |
+
fi
|
| 47 |
+
|
| 48 |
+
eval "$var_name='$input'"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# Function to prompt for yes/no
|
| 52 |
+
prompt_yn() {
|
| 53 |
+
local prompt=$1
|
| 54 |
+
local default=$2
|
| 55 |
+
local var_name=$3
|
| 56 |
+
|
| 57 |
+
while true; do
|
| 58 |
+
if [ "$default" = "y" ]; then
|
| 59 |
+
read -p "$(print_color $GREEN "$prompt") [Y/n]: " yn
|
| 60 |
+
yn=${yn:-y}
|
| 61 |
+
else
|
| 62 |
+
read -p "$(print_color $GREEN "$prompt") [y/N]: " yn
|
| 63 |
+
yn=${yn:-n}
|
| 64 |
+
fi
|
| 65 |
+
|
| 66 |
+
case $yn in
|
| 67 |
+
[Yy]* ) eval "$var_name=true"; break;;
|
| 68 |
+
[Nn]* ) eval "$var_name=false"; break;;
|
| 69 |
+
* ) echo "Please answer yes or no.";;
|
| 70 |
+
esac
|
| 71 |
+
done
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
# Function to prompt for multiple choice
|
| 75 |
+
prompt_choice() {
|
| 76 |
+
local prompt=$1
|
| 77 |
+
local var_name=$2
|
| 78 |
+
shift 2
|
| 79 |
+
local options=("$@")
|
| 80 |
+
|
| 81 |
+
while true; do
|
| 82 |
+
print_color $GREEN "$prompt"
|
| 83 |
+
for i in "${!options[@]}"; do
|
| 84 |
+
echo " $((i+1))) ${options[i]}"
|
| 85 |
+
done
|
| 86 |
+
read -p "Select option (1-${#options[@]}): " choice
|
| 87 |
+
|
| 88 |
+
if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
|
| 89 |
+
eval "$var_name='${options[$((choice-1))]}'"
|
| 90 |
+
break
|
| 91 |
+
else
|
| 92 |
+
print_color $RED "Invalid choice. Please select 1-${#options[@]}."
|
| 93 |
+
fi
|
| 94 |
+
done
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
# Function to prompt for multiple selection using arrow keys and spacebar
|
| 98 |
+
prompt_multiselect() {
|
| 99 |
+
local prompt=$1
|
| 100 |
+
local var_name=$2
|
| 101 |
+
shift 2
|
| 102 |
+
local options=("$@")
|
| 103 |
+
local num_options=${#options[@]}
|
| 104 |
+
local selections=()
|
| 105 |
+
for (( i=0; i<num_options; i++ )); do
|
| 106 |
+
selections[i]=0
|
| 107 |
+
done
|
| 108 |
+
local current_idx=0
|
| 109 |
+
|
| 110 |
+
# Hide cursor for a cleaner UI
|
| 111 |
+
tput civis 2>/dev/null || true
|
| 112 |
+
# Ensure cursor is shown again on exit
|
| 113 |
+
trap 'tput cnorm; return' EXIT
|
| 114 |
+
|
| 115 |
+
# Initial draw
|
| 116 |
+
tput clear
|
| 117 |
+
|
| 118 |
+
while true; do
|
| 119 |
+
# Move cursor to top left
|
| 120 |
+
tput cup 0 0
|
| 121 |
+
|
| 122 |
+
echo -e "${GREEN}${prompt}${NC}"
|
| 123 |
+
echo "(Use UP/DOWN to navigate, SPACE to select/deselect, ENTER to confirm)"
|
| 124 |
+
|
| 125 |
+
for i in "${!options[@]}"; do
|
| 126 |
+
local checkbox="[ ]"
|
| 127 |
+
if [[ ${selections[i]} -eq 1 ]]; then
|
| 128 |
+
checkbox="[${GREEN}x${NC}]"
|
| 129 |
+
fi
|
| 130 |
+
|
| 131 |
+
if [ "$i" -eq "$current_idx" ]; then
|
| 132 |
+
echo -e " ${CYAN}> $checkbox ${options[i]}${NC}"
|
| 133 |
+
else
|
| 134 |
+
echo -e " $checkbox ${options[i]}"
|
| 135 |
+
fi
|
| 136 |
+
done
|
| 137 |
+
# Clear rest of the screen
|
| 138 |
+
tput ed
|
| 139 |
+
|
| 140 |
+
# Read a single keystroke.
|
| 141 |
+
# IFS= ensures space is read as a character, not a delimiter.
|
| 142 |
+
IFS= read -rsn1 key
|
| 143 |
+
|
| 144 |
+
# Handle escape sequences for arrow keys
|
| 145 |
+
if [[ "$key" == $'\e' ]]; then
|
| 146 |
+
read -rsn2 -t 0.1 key
|
| 147 |
+
fi
|
| 148 |
+
|
| 149 |
+
case "$key" in
|
| 150 |
+
'[A') # Up arrow
|
| 151 |
+
current_idx=$(( (current_idx - 1 + num_options) % num_options ))
|
| 152 |
+
;;
|
| 153 |
+
'[B') # Down arrow
|
| 154 |
+
current_idx=$(( (current_idx + 1) % num_options ))
|
| 155 |
+
;;
|
| 156 |
+
' ') # Space bar
|
| 157 |
+
if [[ ${selections[current_idx]} -eq 1 ]]; then
|
| 158 |
+
selections[current_idx]=0
|
| 159 |
+
else
|
| 160 |
+
selections[current_idx]=1
|
| 161 |
+
fi
|
| 162 |
+
;;
|
| 163 |
+
'') # Enter key
|
| 164 |
+
break
|
| 165 |
+
;;
|
| 166 |
+
esac
|
| 167 |
+
done
|
| 168 |
+
|
| 169 |
+
# Show cursor again and clear the trap
|
| 170 |
+
tput cnorm 2>/dev/null || true
|
| 171 |
+
trap - EXIT
|
| 172 |
+
|
| 173 |
+
# Go back to the bottom of the screen
|
| 174 |
+
tput cup $(tput lines) 0
|
| 175 |
+
clear # Clean up the interactive menu from screen
|
| 176 |
+
|
| 177 |
+
# Collect selected options
|
| 178 |
+
local selected_options=()
|
| 179 |
+
for i in "${!options[@]}"; do
|
| 180 |
+
if [[ ${selections[i]} -eq 1 ]]; then
|
| 181 |
+
selected_options+=("${options[i]}")
|
| 182 |
+
fi
|
| 183 |
+
done
|
| 184 |
+
|
| 185 |
+
# Format the result list for Hydra/Python: '["item1", "item2"]'
|
| 186 |
+
if [ ${#selected_options[@]} -gt 0 ]; then
|
| 187 |
+
local formatted_items=""
|
| 188 |
+
for item in "${selected_options[@]}"; do
|
| 189 |
+
if [ -n "$formatted_items" ]; then
|
| 190 |
+
# Each item is wrapped in double quotes
|
| 191 |
+
formatted_items="$formatted_items,\"$item\""
|
| 192 |
+
else
|
| 193 |
+
formatted_items="\"$item\""
|
| 194 |
+
fi
|
| 195 |
+
done
|
| 196 |
+
# The whole list is wrapped in brackets
|
| 197 |
+
eval "$var_name='[$formatted_items]'"
|
| 198 |
+
else
|
| 199 |
+
# Return an empty string if nothing is selected
|
| 200 |
+
eval "$var_name=''"
|
| 201 |
+
fi
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Function to validate file path
|
| 206 |
+
validate_file() {
|
| 207 |
+
local file_path=$1
|
| 208 |
+
if [ ! -f "$file_path" ]; then
|
| 209 |
+
print_color $RED "File not found: $file_path"
|
| 210 |
+
return 1
|
| 211 |
+
fi
|
| 212 |
+
return 0
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
convert_path_if_needed() {
|
| 216 |
+
local input_path="$1"
|
| 217 |
+
|
| 218 |
+
# Return immediately if the path is empty
|
| 219 |
+
if [[ -z "$input_path" ]]; then
|
| 220 |
+
echo ""
|
| 221 |
+
return
|
| 222 |
+
fi
|
| 223 |
+
|
| 224 |
+
local uname_out
|
| 225 |
+
uname_out="$(uname -s)"
|
| 226 |
+
|
| 227 |
+
case "$uname_out" in
|
| 228 |
+
CYGWIN*|MINGW*|MSYS*)
|
| 229 |
+
cygpath -w "$input_path"
|
| 230 |
+
;;
|
| 231 |
+
*)
|
| 232 |
+
echo "$input_path"
|
| 233 |
+
;;
|
| 234 |
+
esac
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
# Main script starts here
|
| 238 |
+
print_color $PURPLE "╔═══════════════════════════════════════════╗"
|
| 239 |
+
print_color $PURPLE "║ Mapperatorinator CLI ║"
|
| 240 |
+
print_color $PURPLE "║ Interactive Inference Setup ║"
|
| 241 |
+
print_color $PURPLE "╚═══════════════════════════════════════════╝"
|
| 242 |
+
echo
|
| 243 |
+
|
| 244 |
+
# 2. Required Paths
|
| 245 |
+
print_header "Required Paths"
|
| 246 |
+
|
| 247 |
+
# Python Path
|
| 248 |
+
prompt_input "Python executable path" "python" python_executable
|
| 249 |
+
|
| 250 |
+
# Audio Path (Required)
|
| 251 |
+
while true; do
|
| 252 |
+
prompt_input "Audio file path (required)" "input/demo.mp3" audio_path
|
| 253 |
+
if [ -z "$audio_path" ]; then
|
| 254 |
+
print_color $RED "Audio path is required!"
|
| 255 |
+
continue
|
| 256 |
+
fi
|
| 257 |
+
if validate_file "$audio_path"; then
|
| 258 |
+
break
|
| 259 |
+
fi
|
| 260 |
+
done
|
| 261 |
+
|
| 262 |
+
# Output Path
|
| 263 |
+
prompt_input "Output directory path" "$(dirname "$audio_path")" output_path
|
| 264 |
+
|
| 265 |
+
# Beatmap Path (Optional)
|
| 266 |
+
prompt_input "Beatmap file path (optional, for in-context learning)" "" beatmap_path
|
| 267 |
+
if [ -n "$beatmap_path" ] && ! validate_file "$beatmap_path"; then
|
| 268 |
+
print_color $YELLOW "Warning: Beatmap file not found, continuing without it"
|
| 269 |
+
beatmap_path=""
|
| 270 |
+
fi
|
| 271 |
+
|
| 272 |
+
# Convert paths to Windows format if needed (for Cygwin/MinGW)
|
| 273 |
+
audio_path=$(convert_path_if_needed "$audio_path")
|
| 274 |
+
output_path=$(convert_path_if_needed "$output_path")
|
| 275 |
+
beatmap_path=$(convert_path_if_needed "$beatmap_path")
|
| 276 |
+
|
| 277 |
+
# 3. Basic Settings
|
| 278 |
+
print_header "Basic Settings"
|
| 279 |
+
|
| 280 |
+
# Model Selection
|
| 281 |
+
model_options=(
|
| 282 |
+
"v28:Mapperatorinator V28"
|
| 283 |
+
"v29:Mapperatorinator V29 (Supports gamemodes and descriptors)"
|
| 284 |
+
"v30:Mapperatorinator V30 (Best stable model)"
|
| 285 |
+
"v31:Mapperatorinator V31 (Slightly more accurate than V29)"
|
| 286 |
+
"beatheritage_v1:BeatHeritage V1 (Enhanced stability & quality)"
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
print_color $GREEN "Select Model:"
|
| 290 |
+
for i in "${!model_options[@]}"; do
|
| 291 |
+
IFS=':' read -r value desc <<< "${model_options[i]}"
|
| 292 |
+
echo " $((i+1))) $desc"
|
| 293 |
+
done
|
| 294 |
+
|
| 295 |
+
while true; do
|
| 296 |
+
read -p "Select model (1-${#model_options[@]}) [default: 5 - BeatHeritage V1]: " model_choice
|
| 297 |
+
model_choice=${model_choice:-5}
|
| 298 |
+
|
| 299 |
+
if [[ "$model_choice" =~ ^[1-5]$ ]]; then
|
| 300 |
+
IFS=':' read -r model_config model_desc <<< "${model_options[$((model_choice-1))]}"
|
| 301 |
+
print_color $BLUE "Selected: $model_desc"
|
| 302 |
+
break
|
| 303 |
+
else
|
| 304 |
+
print_color $RED "Invalid choice. Please select 1-${#model_options[@]}."
|
| 305 |
+
fi
|
| 306 |
+
done
|
| 307 |
+
|
| 308 |
+
# Game Mode (MODIFIED BLOCK)
|
| 309 |
+
gamemode_options=("osu!" "Taiko" "Catch" "Mania")
|
| 310 |
+
while true; do
|
| 311 |
+
print_color $GREEN "Game mode:"
|
| 312 |
+
for i in "${!gamemode_options[@]}"; do
|
| 313 |
+
echo " $i) ${gamemode_options[$i]}"
|
| 314 |
+
done
|
| 315 |
+
read -p "$(print_color $GREEN "Select option (0-3)") [default: 0]: " gamemode_input
|
| 316 |
+
# Set default value to 0 if input is empty
|
| 317 |
+
gamemode=${gamemode_input:-0}
|
| 318 |
+
|
| 319 |
+
if [[ "$gamemode" =~ ^[0-3]$ ]]; then
|
| 320 |
+
break
|
| 321 |
+
else
|
| 322 |
+
print_color $RED "Invalid choice. Please select a number between 0 and 3."
|
| 323 |
+
echo # Add a blank line for spacing before re-prompting
|
| 324 |
+
fi
|
| 325 |
+
done
|
| 326 |
+
|
| 327 |
+
# Difficulty
|
| 328 |
+
prompt_input "Difficulty (1.0-10.0)" "5.5" difficulty
|
| 329 |
+
|
| 330 |
+
# Year
|
| 331 |
+
# default is 2023, and 2007-2023 are valid years
|
| 332 |
+
prompt_input "Year" "2023" year
|
| 333 |
+
if ! [[ "$year" =~ ^(200[7-9]|201[0-9]|202[0-3])$ ]]; then
|
| 334 |
+
print_color $RED "Invalid year! Year must be between 2007 and 2023. Defaulting to 2023."
|
| 335 |
+
year=2023
|
| 336 |
+
fi
|
| 337 |
+
|
| 338 |
+
# 4. Advanced Settings (Optional)
|
| 339 |
+
print_header "Advanced Settings (Optional - Press Enter to skip)"
|
| 340 |
+
print_color $BLUE "Difficulty Settings:"
|
| 341 |
+
prompt_input "HP Drain Rate (0-10)" "" hp_drain_rate
|
| 342 |
+
prompt_input "Circle Size (0-10)" "" circle_size
|
| 343 |
+
prompt_input "Overall Difficulty (0-10)" "" overall_difficulty
|
| 344 |
+
prompt_input "Approach Rate (0-10)" "" approach_rate
|
| 345 |
+
print_color $BLUE "Slider Settings:"
|
| 346 |
+
prompt_input "Slider Multiplier" "" slider_multiplier
|
| 347 |
+
prompt_input "Slider Tick Rate" "" slider_tick_rate
|
| 348 |
+
if [ "$gamemode" -eq 3 ]; then
|
| 349 |
+
print_color $BLUE "Mania Settings:"
|
| 350 |
+
prompt_input "Key Count" "" keycount
|
| 351 |
+
prompt_input "Hold Note Ratio (0-1)" "" hold_note_ratio
|
| 352 |
+
prompt_input "Scroll Speed Ratio" "" scroll_speed_ratio
|
| 353 |
+
fi
|
| 354 |
+
print_color $BLUE "Generation Settings:"
|
| 355 |
+
prompt_input "CFG Scale (1-20)" "" cfg_scale
|
| 356 |
+
prompt_input "Temperature (0-2)" "" temperature
|
| 357 |
+
prompt_input "Top P (0-1)" "" top_p
|
| 358 |
+
prompt_input "Seed (random if empty)" "" seed
|
| 359 |
+
prompt_input "Mapper ID" "" mapper_id
|
| 360 |
+
print_color $BLUE "Timing Settings:"
|
| 361 |
+
prompt_input "Start Time (seconds)" "" start_time
|
| 362 |
+
prompt_input "End Time (seconds)" "" end_time
|
| 363 |
+
|
| 364 |
+
# 5. Boolean Options
|
| 365 |
+
print_header "Export & Processing Options"
|
| 366 |
+
prompt_yn "Export as .osz file?" "n" export_osz
|
| 367 |
+
prompt_yn "Add to existing beatmap?" "n" add_to_beatmap
|
| 368 |
+
prompt_yn "Add hitsounds?" "n" hitsounded
|
| 369 |
+
prompt_yn "Use super timing analysis?" "n" super_timing
|
| 370 |
+
|
| 371 |
+
# 6. Descriptors
|
| 372 |
+
print_header "Style Descriptors"
|
| 373 |
+
|
| 374 |
+
# Positive descriptors with interactive multi-select
|
| 375 |
+
descriptor_options=("jump aim" "stream" "tech" "aim" "speed" "flow" "clean" "complex" "simple" "modern" "classic" "spaced" "stacked")
|
| 376 |
+
prompt_multiselect "Positive descriptors (describe desired mapping style):" descriptors "${descriptor_options[@]}"
|
| 377 |
+
|
| 378 |
+
# Negative descriptors with interactive multi-select
|
| 379 |
+
prompt_multiselect "Negative descriptors (styles to avoid):" negative_descriptors "${descriptor_options[@]}"
|
| 380 |
+
|
| 381 |
+
# In-context options (only if beatmap is provided)
|
| 382 |
+
if [ -n "$beatmap_path" ]; then
|
| 383 |
+
print_header "In-Context Learning Options"
|
| 384 |
+
context_options_list=("timing" "patterns" "structure" "style")
|
| 385 |
+
prompt_multiselect "In-context learning aspects:" in_context_options "${context_options_list[@]}"
|
| 386 |
+
fi
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# 7. Build and Execute Command
|
| 390 |
+
print_header "Command Generation"
|
| 391 |
+
|
| 392 |
+
# Start building the command
|
| 393 |
+
cmd_args=("$python_executable" "inference.py" "-cn" "$model_config")
|
| 394 |
+
|
| 395 |
+
# Helper function to add argument. Wraps value in single quotes.
|
| 396 |
+
add_arg() {
|
| 397 |
+
local key=$1
|
| 398 |
+
local value=$2
|
| 399 |
+
if [ -n "$value" ]; then
|
| 400 |
+
# This format 'key=value' is robust for Hydra, even with complex values
|
| 401 |
+
# like lists represented as strings: descriptors='["item1", "item2"]'
|
| 402 |
+
cmd_args+=("${key}=${value}") # Removed extra quotes for direct execution
|
| 403 |
+
fi
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
# Helper function to add boolean argument
|
| 407 |
+
add_bool_arg() {
|
| 408 |
+
local key=$1
|
| 409 |
+
local value=$2
|
| 410 |
+
if [ "$value" = "true" ]; then
|
| 411 |
+
cmd_args+=("${key}=true")
|
| 412 |
+
else
|
| 413 |
+
cmd_args+=("${key}=false")
|
| 414 |
+
fi
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
# Add all arguments
|
| 418 |
+
add_arg "audio_path" "'$audio_path'"
|
| 419 |
+
add_arg "output_path" "'$output_path'"
|
| 420 |
+
add_arg "beatmap_path" "'$beatmap_path'"
|
| 421 |
+
add_arg "gamemode" "$gamemode"
|
| 422 |
+
add_arg "difficulty" "$difficulty"
|
| 423 |
+
add_arg "year" "$year"
|
| 424 |
+
|
| 425 |
+
# Optional numeric parameters
|
| 426 |
+
add_arg "hp_drain_rate" "$hp_drain_rate"
|
| 427 |
+
add_arg "circle_size" "$circle_size"
|
| 428 |
+
add_arg "overall_difficulty" "$overall_difficulty"
|
| 429 |
+
add_arg "approach_rate" "$approach_rate"
|
| 430 |
+
add_arg "slider_multiplier" "$slider_multiplier"
|
| 431 |
+
add_arg "slider_tick_rate" "$slider_tick_rate"
|
| 432 |
+
add_arg "keycount" "$keycount"
|
| 433 |
+
add_arg "hold_note_ratio" "$hold_note_ratio"
|
| 434 |
+
add_arg "scroll_speed_ratio" "$scroll_speed_ratio"
|
| 435 |
+
add_arg "cfg_scale" "$cfg_scale"
|
| 436 |
+
add_arg "temperature" "$temperature"
|
| 437 |
+
add_arg "top_p" "$top_p"
|
| 438 |
+
add_arg "seed" "$seed"
|
| 439 |
+
add_arg "mapper_id" "$mapper_id"
|
| 440 |
+
add_arg "start_time" "$start_time"
|
| 441 |
+
add_arg "end_time" "$end_time"
|
| 442 |
+
|
| 443 |
+
# List parameters (now correctly quoted)
|
| 444 |
+
add_arg "descriptors" "$descriptors"
|
| 445 |
+
add_arg "negative_descriptors" "$negative_descriptors"
|
| 446 |
+
add_arg "in_context" "$in_context_options"
|
| 447 |
+
|
| 448 |
+
# Boolean parameters
|
| 449 |
+
add_bool_arg "export_osz" "$export_osz"
|
| 450 |
+
add_bool_arg "add_to_beatmap" "$add_to_beatmap"
|
| 451 |
+
add_bool_arg "hitsounded" "$hitsounded"
|
| 452 |
+
add_bool_arg "super_timing" "$super_timing"
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# Display the command
|
| 456 |
+
print_color $YELLOW "Generated command:"
|
| 457 |
+
echo
|
| 458 |
+
# Use printf for safer printing of arguments
|
| 459 |
+
printf "%s " "${cmd_args[@]}"
|
| 460 |
+
echo
|
| 461 |
+
echo
|
| 462 |
+
|
| 463 |
+
# Ask for confirmation
|
| 464 |
+
prompt_yn "Execute this command?" "y" execute_cmd
|
| 465 |
+
|
| 466 |
+
if [ "$execute_cmd" = "true" ]; then
|
| 467 |
+
print_header "Executing Inference"
|
| 468 |
+
print_color $GREEN "Starting inference process..."
|
| 469 |
+
echo
|
| 470 |
+
|
| 471 |
+
# Execute the command by expanding the array. No need for eval.
|
| 472 |
+
"${cmd_args[@]}"
|
| 473 |
+
|
| 474 |
+
exit_code=$?
|
| 475 |
+
echo
|
| 476 |
+
if [ $exit_code -eq 0 ]; then
|
| 477 |
+
print_color $GREEN "✓ Inference completed successfully!"
|
| 478 |
+
else
|
| 479 |
+
print_color $RED "✗ Inference failed with exit code: $exit_code"
|
| 480 |
+
fi
|
| 481 |
+
else
|
| 482 |
+
print_color $YELLOW "Command generation cancelled."
|
| 483 |
+
echo
|
| 484 |
+
print_color $BLUE "You can copy and run the command manually:"
|
| 485 |
+
# Use printf for safer printing of arguments
|
| 486 |
+
printf "%s " "${cmd_args[@]}"
|
| 487 |
+
echo
|
| 488 |
+
fi
|
| 489 |
+
|
| 490 |
+
echo
|
| 491 |
+
print_color $PURPLE "Thank you for using Mapperatorinator CLI!"
|
colab/beatheritage_v1_inference.ipynb
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"metadata": {},
|
| 5 |
+
"cell_type": "markdown",
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# BeatHeritage V1 - Beatmap Generator\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"An enhanced AI model for generating osu! beatmaps with improved stability and quality control.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"### Instructions:\n",
|
| 15 |
+
"1. **Read and accept the rules** by clicking the checkbox in the first cell\n",
|
| 16 |
+
"2. **Ensure GPU runtime**: Go to __Runtime → Change Runtime Type → GPU__\n",
|
| 17 |
+
"3. **Execute cells in order**: Click ▶️ on each cell sequentially\n",
|
| 18 |
+
"4. **Upload your audio**: Choose an MP3/OGG file when prompted\n",
|
| 19 |
+
"5. **Configure parameters**: Adjust settings to your preference\n",
|
| 20 |
+
"6. **Generate beatmap**: Run the generation cell and wait for results\n"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": [
|
| 27 |
+
"#@title 🚀 Setup Environment { display-mode: \"form\" }\n",
|
| 28 |
+
"#@markdown ### ⚠️ Important: Please use this tool responsibly\n",
|
| 29 |
+
"#@markdown - Always disclose AI usage in your beatmap descriptions\n",
|
| 30 |
+
"#@markdown - Respect the original music artists and mappers\n",
|
| 31 |
+
"#@markdown - This tool is for educational and creative purposes\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"i_accept_the_rules = False #@param {type:\"boolean\"}\n",
|
| 34 |
+
"#@markdown ☑️ **I accept the rules and will use this tool responsibly**\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"import os\n",
|
| 37 |
+
"import sys\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"if not i_accept_the_rules:\n",
|
| 40 |
+
" raise ValueError(\"Please read and accept the rules before proceeding!\")\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"print(\"Installing BeatHeritage...\")\n",
|
| 43 |
+
"print(\"=\"*50)\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# Clone repository if not exists\n",
|
| 46 |
+
"if not os.path.exists('/content/BeatHeritage'):\n",
|
| 47 |
+
" !git clone -q https://github.com/hongminh54/BeatHeritage.git\n",
|
| 48 |
+
" print(\"✅ Repository cloned\")\n",
|
| 49 |
+
"else:\n",
|
| 50 |
+
" print(\"✅ Repository already exists\")\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"%cd /content/BeatHeritage\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# Install dependencies\n",
|
| 55 |
+
"print(\"\\nInstalling dependencies...\")\n",
|
| 56 |
+
"!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
|
| 57 |
+
"!pip install -q -r requirements.txt\n",
|
| 58 |
+
"!apt-get install -y ffmpeg > /dev/null 2>&1\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"print(\"\\nSetup complete!\")\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"# Import required libraries\n",
|
| 63 |
+
"import warnings\n",
|
| 64 |
+
"warnings.filterwarnings('ignore')\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"import torch\n",
|
| 67 |
+
"from google.colab import files\n",
|
| 68 |
+
"from IPython.display import display, HTML, Audio\n",
|
| 69 |
+
"from pathlib import Path\n",
|
| 70 |
+
"import json\n",
|
| 71 |
+
"import shlex\n",
|
| 72 |
+
"import subprocess\n",
|
| 73 |
+
"from datetime import datetime\n",
|
| 74 |
+
"import zipfile\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# Check GPU availability\n",
|
| 77 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 78 |
+
"print(f\"\\nUsing device: {device}\")\n",
|
| 79 |
+
"if device == 'cuda':\n",
|
| 80 |
+
" gpu_name = torch.cuda.get_device_name(0)\n",
|
| 81 |
+
" gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3\n",
|
| 82 |
+
" print(f\"GPU: {gpu_name}\")\n",
|
| 83 |
+
" print(f\"Memory: {gpu_memory:.1f} GB\")\n",
|
| 84 |
+
"else:\n",
|
| 85 |
+
" print(\"No GPU detected! Generation will be VERY slow.\")\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"# Initialize global variables\n",
|
| 88 |
+
"audio_path = \"\"\n",
|
| 89 |
+
"output_path = \"/content/BeatHeritage/output\"\n",
|
| 90 |
+
"os.makedirs(output_path, exist_ok=True)"
|
| 91 |
+
],
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"execution_count": null
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"cell_type": "code",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"source": [
|
| 99 |
+
"#@title 🎵 Upload Audio File { display-mode: \"form\" }\n",
|
| 100 |
+
"#@markdown Upload your audio file (MP3, OGG, or WAV format)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def upload_and_validate_audio():\n",
|
| 103 |
+
" \"\"\"Upload and validate audio file with proper error handling\"\"\"\n",
|
| 104 |
+
" global audio_path\n",
|
| 105 |
+
" \n",
|
| 106 |
+
" print(\"Please select an audio file to upload...\")\n",
|
| 107 |
+
" uploaded = files.upload()\n",
|
| 108 |
+
" \n",
|
| 109 |
+
" if not uploaded:\n",
|
| 110 |
+
" print(\"No file uploaded\")\n",
|
| 111 |
+
" return None\n",
|
| 112 |
+
" \n",
|
| 113 |
+
" # Get the first uploaded file\n",
|
| 114 |
+
" original_filename = list(uploaded.keys())[0]\n",
|
| 115 |
+
" \n",
|
| 116 |
+
" # Clean filename - remove special characters and spaces\n",
|
| 117 |
+
" import re\n",
|
| 118 |
+
" clean_filename = re.sub(r'[^a-zA-Z0-9._-]', '_', original_filename)\n",
|
| 119 |
+
" clean_filename = clean_filename.replace(' ', '_')\n",
|
| 120 |
+
" \n",
|
| 121 |
+
" # Ensure proper extension\n",
|
| 122 |
+
" if not any(clean_filename.lower().endswith(ext) for ext in ['.mp3', '.ogg', '.wav']):\n",
|
| 123 |
+
" print(f\"Invalid file format: {original_filename}\")\n",
|
| 124 |
+
" print(\"Please upload an MP3, OGG, or WAV file\")\n",
|
| 125 |
+
" return None\n",
|
| 126 |
+
" \n",
|
| 127 |
+
" # Save with cleaned filename\n",
|
| 128 |
+
" audio_path = f'/content/BeatHeritage/{clean_filename}'\n",
|
| 129 |
+
" \n",
|
| 130 |
+
" # Write the uploaded content to the new path\n",
|
| 131 |
+
" with open(audio_path, 'wb') as f:\n",
|
| 132 |
+
" f.write(uploaded[original_filename])\n",
|
| 133 |
+
" \n",
|
| 134 |
+
" print(f\"Audio uploaded successfully!\")\n",
|
| 135 |
+
" print(f\"Original: {original_filename}\")\n",
|
| 136 |
+
" print(f\"Saved as: {clean_filename}\")\n",
|
| 137 |
+
" print(f\"Path: {audio_path}\")\n",
|
| 138 |
+
" \n",
|
| 139 |
+
" # Display audio player\n",
|
| 140 |
+
" display(Audio(audio_path))\n",
|
| 141 |
+
" \n",
|
| 142 |
+
" return audio_path\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"# Upload audio\n",
|
| 145 |
+
"audio_path = upload_and_validate_audio()\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"if not audio_path:\n",
|
| 148 |
+
" print(\"\\n⚠Please run this cell again and upload a valid audio file\")"
|
| 149 |
+
],
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"execution_count": null
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "code",
|
| 155 |
+
"metadata": {},
|
| 156 |
+
"source": [
|
| 157 |
+
"#@title ⚙️ Configure Generation Parameters { display-mode: \"form\" }\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"#@markdown ### 🎯 Basic Settings\n",
|
| 160 |
+
"#@markdown ---\n",
|
| 161 |
+
"#@markdown Choose the AI model version to use:\n",
|
| 162 |
+
"model_version = \"BeatHeritage V1 (Enhanced)\" #@param [\"BeatHeritage V1 (Enhanced)\", \"Mapperatorinator V30\", \"Mapperatorinator V29\", \"Mapperatorinator V28\"]\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"#@markdown Select the game mode for your beatmap:\n",
|
| 165 |
+
"gamemode = \"Standard\" #@param [\"Standard\", \"Taiko\", \"Catch the Beat\", \"Mania\"]\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"#@markdown Target difficulty (★ rating):\n",
|
| 168 |
+
"difficulty = 5.5 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"#@markdown ### 🎨 Style Configuration\n",
|
| 171 |
+
"#@markdown ---\n",
|
| 172 |
+
"#@markdown Primary mapping style descriptor:\n",
|
| 173 |
+
"descriptor_1 = \"clean\" #@param [\"clean\", \"tech\", \"jump aim\", \"stream\", \"aim\", \"speed\", \"flow\", \"complex\", \"simple\", \"modern\", \"classic\", \"slider tech\", \"alt\", \"precision\", \"stamina\"]\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"#@markdown Secondary style descriptor (optional):\n",
|
| 176 |
+
"descriptor_2 = \"\" #@param [\"\", \"clean\", \"tech\", \"jump aim\", \"stream\", \"aim\", \"speed\", \"flow\", \"complex\", \"simple\", \"modern\", \"classic\", \"slider tech\", \"alt\", \"precision\", \"stamina\"]\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"#@markdown ### 🔧 Advanced Parameters\n",
|
| 179 |
+
"#@markdown ---\n",
|
| 180 |
+
"#@markdown Generation temperature (lower = more conservative):\n",
|
| 181 |
+
"temperature = 0.85 #@param {type:\"slider\", min:0.1, max:2.0, step:0.05}\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"#@markdown Top-p sampling (nucleus sampling):\n",
|
| 184 |
+
"top_p = 0.92 #@param {type:\"slider\", min:0.1, max:1.0, step:0.01}\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"#@markdown Classifier-free guidance scale:\n",
|
| 187 |
+
"cfg_scale = 7.5 #@param {type:\"slider\", min:1.0, max:20.0, step:0.5}\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"#@markdown ### 📊 Quality Control (BeatHeritage V1)\n",
|
| 190 |
+
"#@markdown ---\n",
|
| 191 |
+
"enable_auto_correction = True #@param {type:\"boolean\"}\n",
|
| 192 |
+
"enable_flow_optimization = True #@param {type:\"boolean\"}\n",
|
| 193 |
+
"enable_pattern_variety = True #@param {type:\"boolean\"}\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"#@markdown ### 🎯 Export Options\n",
|
| 196 |
+
"#@markdown ---\n",
|
| 197 |
+
"super_timing = False #@param {type:\"boolean\"}\n",
|
| 198 |
+
"#@markdown Enable for songs with variable BPM (slower generation)\n",
|
| 199 |
+
"\n",
|
| 200 |
+
"export_osz = True #@param {type:\"boolean\"}\n",
|
| 201 |
+
"#@markdown Export as .osz package (includes audio)\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"# Map model names to config names\n",
|
| 204 |
+
"model_configs = {\n",
|
| 205 |
+
" \"BeatHeritage V1 (Enhanced)\": \"beatheritage_v1\",\n",
|
| 206 |
+
" \"Mapperatorinator V30\": \"v30\",\n",
|
| 207 |
+
" \"Mapperatorinator V29\": \"v29\",\n",
|
| 208 |
+
" \"Mapperatorinator V28\": \"v28\"\n",
|
| 209 |
+
"}\n",
|
| 210 |
+
"\n",
|
| 211 |
+
"# Map gamemode names to indices\n",
|
| 212 |
+
"gamemode_indices = {\n",
|
| 213 |
+
" \"Standard\": 0,\n",
|
| 214 |
+
" \"Taiko\": 1,\n",
|
| 215 |
+
" \"Catch the Beat\": 2,\n",
|
| 216 |
+
" \"Mania\": 3\n",
|
| 217 |
+
"}\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"selected_model = model_configs[model_version]\n",
|
| 220 |
+
"selected_gamemode = gamemode_indices[gamemode]\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"# Build descriptor list\n",
|
| 223 |
+
"descriptors = [d for d in [descriptor_1, descriptor_2] if d]\n",
|
| 224 |
+
"\n",
|
| 225 |
+
"# Display configuration summary\n",
|
| 226 |
+
"print(\"Configuration Summary\")\n",
|
| 227 |
+
"print(\"=\"*50)\n",
|
| 228 |
+
"print(f\"Model: {model_version}\")\n",
|
| 229 |
+
"print(f\"Game Mode: {gamemode}\")\n",
|
| 230 |
+
"print(f\"Difficulty: {difficulty}★\")\n",
|
| 231 |
+
"print(f\"Style: {', '.join(descriptors) if descriptors else 'Default'}\")\n",
|
| 232 |
+
"print(f\"Temperature: {temperature}\")\n",
|
| 233 |
+
"print(f\"Top-p: {top_p}\")\n",
|
| 234 |
+
"print(f\"CFG Scale: {cfg_scale}\")\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"if selected_model == \"beatheritage_v1\":\n",
|
| 237 |
+
" print(\"\\nBeatHeritage V1 Features:\")\n",
|
| 238 |
+
" if enable_auto_correction:\n",
|
| 239 |
+
" print(\" ✓ Auto-correction enabled\")\n",
|
| 240 |
+
" if enable_flow_optimization:\n",
|
| 241 |
+
" print(\" ✓ Flow optimization enabled\")\n",
|
| 242 |
+
" if enable_pattern_variety:\n",
|
| 243 |
+
" print(\" ✓ Pattern variety enabled\")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"if super_timing:\n",
|
| 246 |
+
" print(\"\\nSuper timing enabled (for variable BPM)\")\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"print(\"\\nConfiguration ready!\")"
|
| 249 |
+
],
|
| 250 |
+
"outputs": [],
|
| 251 |
+
"execution_count": null
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "code",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"source": [
|
| 257 |
+
"#@title 🎮 Generate Beatmap { display-mode: \"form\" }\n",
|
| 258 |
+
"#@markdown Click the play button to start generation. This may take a few minutes depending on song length.\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"def generate_beatmap():\n",
|
| 261 |
+
" \"\"\"Generate beatmap with proper error handling and progress tracking\"\"\"\n",
|
| 262 |
+
" \n",
|
| 263 |
+
" if not audio_path or not os.path.exists(audio_path):\n",
|
| 264 |
+
" print(\"Error: No audio file found!\")\n",
|
| 265 |
+
" print(\"Please upload an audio file first.\")\n",
|
| 266 |
+
" return None\n",
|
| 267 |
+
" \n",
|
| 268 |
+
" print(\"Starting beatmap generation...\")\n",
|
| 269 |
+
" print(\"=\"*50)\n",
|
| 270 |
+
" print(f\"Audio: {os.path.basename(audio_path)}\")\n",
|
| 271 |
+
" print(f\"Model: {model_version}\")\n",
|
| 272 |
+
" print(f\"Mode: {gamemode}\")\n",
|
| 273 |
+
" print(f\"Difficulty: {difficulty}★\")\n",
|
| 274 |
+
" print(\"=\"*50)\n",
|
| 275 |
+
" print()\n",
|
| 276 |
+
" \n",
|
| 277 |
+
" # Build command with proper escaping\n",
|
| 278 |
+
" cmd = [\n",
|
| 279 |
+
" 'python', 'inference.py',\n",
|
| 280 |
+
" '-cn', selected_model,\n",
|
| 281 |
+
" f'audio_path={shlex.quote(audio_path)}',\n",
|
| 282 |
+
" f'output_path={shlex.quote(output_path)}',\n",
|
| 283 |
+
" f'gamemode={selected_gamemode}',\n",
|
| 284 |
+
" f'difficulty={difficulty}',\n",
|
| 285 |
+
" f'temperature={temperature}',\n",
|
| 286 |
+
" f'top_p={top_p}',\n",
|
| 287 |
+
" f'cfg_scale={cfg_scale}',\n",
|
| 288 |
+
" f'super_timing={str(super_timing).lower()}',\n",
|
| 289 |
+
" f'export_osz={str(export_osz).lower()}',\n",
|
| 290 |
+
" ]\n",
|
| 291 |
+
" \n",
|
| 292 |
+
" # Add descriptors if specified\n",
|
| 293 |
+
" if descriptors:\n",
|
| 294 |
+
" desc_str = json.dumps(descriptors)\n",
|
| 295 |
+
" cmd.append(f'descriptors={shlex.quote(desc_str)}')\n",
|
| 296 |
+
" \n",
|
| 297 |
+
" # Add BeatHeritage V1 specific features\n",
|
| 298 |
+
" if selected_model == \"beatheritage_v1\":\n",
|
| 299 |
+
" if enable_auto_correction:\n",
|
| 300 |
+
" cmd.append('quality_control.enable_auto_correction=true')\n",
|
| 301 |
+
" if enable_flow_optimization:\n",
|
| 302 |
+
" cmd.append('quality_control.enable_flow_optimization=true')\n",
|
| 303 |
+
" if enable_pattern_variety:\n",
|
| 304 |
+
" cmd.append('advanced_features.enable_pattern_variety=true')\n",
|
| 305 |
+
" \n",
|
| 306 |
+
" # Always enable these for V1\n",
|
| 307 |
+
" cmd.extend([\n",
|
| 308 |
+
" 'advanced_features.enable_context_aware_generation=true',\n",
|
| 309 |
+
" 'advanced_features.enable_style_preservation=true',\n",
|
| 310 |
+
" 'generate_positions=true',\n",
|
| 311 |
+
" 'position_refinement=true'\n",
|
| 312 |
+
" ])\n",
|
| 313 |
+
" \n",
|
| 314 |
+
" # Execute command\n",
|
| 315 |
+
" try:\n",
|
| 316 |
+
" print(\"⏳ Generating beatmap... (this may take several minutes)\\n\")\n",
|
| 317 |
+
" \n",
|
| 318 |
+
" # Run the command\n",
|
| 319 |
+
" process = subprocess.Popen(\n",
|
| 320 |
+
" cmd,\n",
|
| 321 |
+
" stdout=subprocess.PIPE,\n",
|
| 322 |
+
" stderr=subprocess.STDOUT,\n",
|
| 323 |
+
" text=True,\n",
|
| 324 |
+
" bufsize=1,\n",
|
| 325 |
+
" universal_newlines=True\n",
|
| 326 |
+
" )\n",
|
| 327 |
+
" \n",
|
| 328 |
+
" # Stream output in real-time\n",
|
| 329 |
+
" for line in process.stdout:\n",
|
| 330 |
+
" print(line, end='')\n",
|
| 331 |
+
" \n",
|
| 332 |
+
" # Wait for completion\n",
|
| 333 |
+
" return_code = process.wait()\n",
|
| 334 |
+
" \n",
|
| 335 |
+
" if return_code == 0:\n",
|
| 336 |
+
" print(\"\\n\" + \"=\"*50)\n",
|
| 337 |
+
" print(\"Beatmap generation complete!\")\n",
|
| 338 |
+
" \n",
|
| 339 |
+
" # List generated files\n",
|
| 340 |
+
" generated_files = list(Path(output_path).glob('*'))\n",
|
| 341 |
+
" if generated_files:\n",
|
| 342 |
+
" print(f\"\\nGenerated {len(generated_files)} file(s):\")\n",
|
| 343 |
+
" for file in generated_files:\n",
|
| 344 |
+
" size_mb = file.stat().st_size / (1024 * 1024)\n",
|
| 345 |
+
" print(f\" • {file.name} ({size_mb:.2f} MB)\")\n",
|
| 346 |
+
" \n",
|
| 347 |
+
" return generated_files\n",
|
| 348 |
+
" else:\n",
|
| 349 |
+
" print(f\"\\nGeneration failed with error code: {return_code}\")\n",
|
| 350 |
+
" return None\n",
|
| 351 |
+
" \n",
|
| 352 |
+
" except Exception as e:\n",
|
| 353 |
+
" print(f\"\\nError during generation: {str(e)}\")\n",
|
| 354 |
+
" print(\"\\nTroubleshooting tips:\")\n",
|
| 355 |
+
" print(\"1. Ensure the audio file is valid\")\n",
|
| 356 |
+
" print(\"2. Check if GPU memory is sufficient\")\n",
|
| 357 |
+
" print(\"3. Try reducing temperature or cfg_scale\")\n",
|
| 358 |
+
" print(\"4. Disable super_timing if enabled\")\n",
|
| 359 |
+
" return None\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"# Run generation\n",
|
| 362 |
+
"generated_files = generate_beatmap()"
|
| 363 |
+
],
|
| 364 |
+
"outputs": [],
|
| 365 |
+
"execution_count": null
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"cell_type": "code",
|
| 369 |
+
"metadata": {},
|
| 370 |
+
"source": [
|
| 371 |
+
"#@title 📥 Download Generated Files { display-mode: \"form\" }\n",
|
| 372 |
+
"#@markdown Download your generated beatmap files\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"def download_results():\n",
|
| 375 |
+
" \"\"\"Package and download generated beatmap files\"\"\"\n",
|
| 376 |
+
" \n",
|
| 377 |
+
" output_files = list(Path(output_path).glob('*'))\n",
|
| 378 |
+
" \n",
|
| 379 |
+
" if not output_files:\n",
|
| 380 |
+
" print(\"No files to download\")\n",
|
| 381 |
+
" print(\"Please generate a beatmap first.\")\n",
|
| 382 |
+
" return\n",
|
| 383 |
+
" \n",
|
| 384 |
+
" print(\"Preparing files for download...\")\n",
|
| 385 |
+
" print(\"=\"*50)\n",
|
| 386 |
+
" \n",
|
| 387 |
+
" # Create timestamp for unique naming\n",
|
| 388 |
+
" timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\n",
|
| 389 |
+
" \n",
|
| 390 |
+
" # Check if we have .osz files\n",
|
| 391 |
+
" osz_files = [f for f in output_files if f.suffix == '.osz']\n",
|
| 392 |
+
" osu_files = [f for f in output_files if f.suffix == '.osu']\n",
|
| 393 |
+
" \n",
|
| 394 |
+
" # Download .osz files directly if available\n",
|
| 395 |
+
" if osz_files:\n",
|
| 396 |
+
" for osz_file in osz_files:\n",
|
| 397 |
+
" print(f\"\\n📥 Downloading: {osz_file.name}\")\n",
|
| 398 |
+
" files.download(str(osz_file))\n",
|
| 399 |
+
" \n",
|
| 400 |
+
" # Download .osu files\n",
|
| 401 |
+
" elif osu_files:\n",
|
| 402 |
+
" if len(osu_files) == 1:\n",
|
| 403 |
+
" # Single file - download directly\n",
|
| 404 |
+
" osu_file = osu_files[0]\n",
|
| 405 |
+
" print(f\"\\n📥 Downloading: {osu_file.name}\")\n",
|
| 406 |
+
" files.download(str(osu_file))\n",
|
| 407 |
+
" else:\n",
|
| 408 |
+
" # Multiple files - create zip\n",
|
| 409 |
+
" zip_name = f'beatheritage_{gamemode.lower()}_{timestamp}.zip'\n",
|
| 410 |
+
" zip_path = f'/content/{zip_name}'\n",
|
| 411 |
+
" \n",
|
| 412 |
+
" with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
|
| 413 |
+
" for file in output_files:\n",
|
| 414 |
+
" zipf.write(file, file.name)\n",
|
| 415 |
+
" print(f\" • Added: {file.name}\")\n",
|
| 416 |
+
" \n",
|
| 417 |
+
" print(f\"\\nDownloading: {zip_name}\")\n",
|
| 418 |
+
" files.download(zip_path)\n",
|
| 419 |
+
" \n",
|
| 420 |
+
" # Also handle other files\n",
|
| 421 |
+
" other_files = [f for f in output_files if f.suffix not in ['.osz', '.osu']]\n",
|
| 422 |
+
" if other_files:\n",
|
| 423 |
+
" print(\"\\nAdditional files generated:\")\n",
|
| 424 |
+
" for file in other_files:\n",
|
| 425 |
+
" print(f\" • {file.name}\")\n",
|
| 426 |
+
" \n",
|
| 427 |
+
" print(\"\\nDownload complete!\")\n",
|
| 428 |
+
" print(\"\\nTips:\")\n",
|
| 429 |
+
" print(\"• .osz files can be opened directly in osu!\")\n",
|
| 430 |
+
" print(\"• .osu files should be placed in your Songs folder\")\n",
|
| 431 |
+
" print(\"• Press F5 in osu! to refresh after adding files\")\n",
|
| 432 |
+
"\n",
|
| 433 |
+
"# Download files\n",
|
| 434 |
+
"download_results()"
|
| 435 |
+
],
|
| 436 |
+
"outputs": [],
|
| 437 |
+
"execution_count": null
|
| 438 |
+
},
|
| 439 |
+
{
|
| 440 |
+
"cell_type": "markdown",
|
| 441 |
+
"metadata": {},
|
| 442 |
+
"source": [
|
| 443 |
+
"---\n",
|
| 444 |
+
"\n",
|
| 445 |
+
"## Additional Information\n",
|
| 446 |
+
"\n",
|
| 447 |
+
"### Tips for Best Results:\n",
|
| 448 |
+
"- **Audio Quality**: Use high-quality audio files (320kbps MP3 or FLAC)\n",
|
| 449 |
+
"- **Difficulty Matching**: Match the difficulty rating to song intensity\n",
|
| 450 |
+
"- **Style Descriptors**: Choose descriptors that match the music genre\n",
|
| 451 |
+
"- **Variable BPM**: Enable `super_timing` for songs with tempo changes\n",
|
| 452 |
+
"\n",
|
| 453 |
+
"### Troubleshooting:\n",
|
| 454 |
+
"\n",
|
| 455 |
+
"**Out of Memory:**\n",
|
| 456 |
+
"- Restart runtime to clear GPU memory\n",
|
| 457 |
+
"- Use shorter songs or segments\n",
|
| 458 |
+
"- Reduce cfg_scale value\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"**Poor Quality Output:**\n",
|
| 461 |
+
"- Lower temperature (0.7-0.8) for stability\n",
|
| 462 |
+
"- Increase cfg_scale (10-15) for stronger guidance\n",
|
| 463 |
+
"- Use more specific descriptors\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"**Generation Errors:**\n",
|
| 466 |
+
"- Ensure audio file has no special characters\n",
|
| 467 |
+
"- Check GPU is enabled in runtime\n",
|
| 468 |
+
"- Try different model versions\n",
|
| 469 |
+
"\n",
|
| 470 |
+
"### Resources:\n",
|
| 471 |
+
"- [GitHub Repository](https://github.com/hongminh54/BeatHeritage)\n",
|
| 472 |
+
"- [Documentation](https://github.com/hongminh54/BeatHeritage/blob/main/README.md)\n",
|
| 473 |
+
"\n",
|
| 474 |
+
"### License & Credits:\n",
|
| 475 |
+
"- BeatHeritage V1 by hongminh54\n",
|
| 476 |
+
"- Based on Mapperatorinator by OliBomby\n",
|
| 477 |
+
"- Please credit AI usage in your beatmap descriptions\n",
|
| 478 |
+
"\n",
|
| 479 |
+
"---"
|
| 480 |
+
]
|
| 481 |
+
}
|
| 482 |
+
],
|
| 483 |
+
"metadata": {
|
| 484 |
+
"kernelspec": {
|
| 485 |
+
"display_name": "Python 3",
|
| 486 |
+
"language": "python",
|
| 487 |
+
"name": "python3"
|
| 488 |
+
},
|
| 489 |
+
"language_info": {
|
| 490 |
+
"codemirror_mode": {
|
| 491 |
+
"name": "ipython",
|
| 492 |
+
"version": 3
|
| 493 |
+
},
|
| 494 |
+
"file_extension": ".py",
|
| 495 |
+
"mimetype": "text/x-python",
|
| 496 |
+
"name": "python",
|
| 497 |
+
"nbconvert_exporter": "python",
|
| 498 |
+
"pygments_lexer": "ipython3",
|
| 499 |
+
"version": "3.10.0"
|
| 500 |
+
},
|
| 501 |
+
"colab": {
|
| 502 |
+
"provenance": [],
|
| 503 |
+
"gpuType": "T4",
|
| 504 |
+
"collapsed_sections": []
|
| 505 |
+
},
|
| 506 |
+
"accelerator": "GPU"
|
| 507 |
+
},
|
| 508 |
+
"nbformat": 4,
|
| 509 |
+
"nbformat_minor": 4
|
| 510 |
+
}
|
colab/classifier_classify.ipynb
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"metadata": {},
|
| 5 |
+
"cell_type": "markdown",
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/classifier_classify.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# Beatmap Mapper Classification\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook is an interactive demo of an osu! beatmap mapper classification model created by OliBomby. This model is capable of predicting which osu! standard ranked mapper mapped any given beatmap by looking at the style. You can use this on your own maps to see which mapper you are most similar to.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"### Instructions for running:\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
|
| 16 |
+
"* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
|
| 17 |
+
"* __Upload Audio__: choose a .mp3 or .ogg file from your computer.\n",
|
| 18 |
+
"* __Upload Beatmap__: choose a .osu file from your computer.\n",
|
| 19 |
+
"* __Configure__: choose the time of the segment which the classifier should classify.\n",
|
| 20 |
+
"* Classify the beatmap using the __Classify Beatmap__ cell.\n"
|
| 21 |
+
],
|
| 22 |
+
"id": "3c19902455e25588"
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"id": "initial_id",
|
| 27 |
+
"metadata": {
|
| 28 |
+
"collapsed": true
|
| 29 |
+
},
|
| 30 |
+
"source": [
|
| 31 |
+
"#@title Setup Environment { display-mode: \"form\" }\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"!git clone https://github.com/OliBomby/Mapperatorinator.git\n",
|
| 34 |
+
"%cd Mapperatorinator\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"!pip install hydra-core lightning nnaudio\n",
|
| 37 |
+
"!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"from google.colab import files\n",
|
| 40 |
+
"from hydra import compose, initialize_config_dir\n",
|
| 41 |
+
"from classifier.classify import main\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"input_audio = \"\"\n",
|
| 44 |
+
"input_beatmap = \"\""
|
| 45 |
+
],
|
| 46 |
+
"outputs": [],
|
| 47 |
+
"execution_count": null
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"cell_type": "code",
|
| 52 |
+
"source": [
|
| 53 |
+
"#@title Upload Audio { display-mode: \"form\" }\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"def upload_audio():\n",
|
| 56 |
+
" data = list(files.upload().keys())\n",
|
| 57 |
+
" if len(data) > 1:\n",
|
| 58 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 59 |
+
" return data[0]\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"input_audio = upload_audio()"
|
| 62 |
+
],
|
| 63 |
+
"id": "624a60c5777279e7",
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"execution_count": null
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"cell_type": "code",
|
| 70 |
+
"source": [
|
| 71 |
+
"#@title Upload Beatmap { display-mode: \"form\" }\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"def upload_beatmap():\n",
|
| 74 |
+
" data = list(files.upload().keys())\n",
|
| 75 |
+
" if len(data) > 1:\n",
|
| 76 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 77 |
+
" return data[0]\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"input_beatmap = upload_beatmap()"
|
| 80 |
+
],
|
| 81 |
+
"id": "63884394491f6664",
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"execution_count": null
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"metadata": {},
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"source": [
|
| 89 |
+
"#@title Configure and Classify Beatmap { display-mode: \"form\" }\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# @markdown #### Input the start time in seconds of the segment to classify.\n",
|
| 92 |
+
"time = 5 # @param {type:\"number\"}\n",
|
| 93 |
+
" \n",
|
| 94 |
+
"# Create config\n",
|
| 95 |
+
"with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/classifier/configs\"):\n",
|
| 96 |
+
" conf = compose(config_name=\"inference\")\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"# Do inference\n",
|
| 99 |
+
"conf.time = time\n",
|
| 100 |
+
"conf.beatmap_path = input_beatmap\n",
|
| 101 |
+
"conf.audio_path = input_audio\n",
|
| 102 |
+
"conf.mappers_path = \"./datasets/beatmap_users.json\"\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"main(conf)\n"
|
| 105 |
+
],
|
| 106 |
+
"id": "166eb3e5f9398554",
|
| 107 |
+
"outputs": [],
|
| 108 |
+
"execution_count": null
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"metadata": {
|
| 112 |
+
"kernelspec": {
|
| 113 |
+
"display_name": "Python 3",
|
| 114 |
+
"language": "python",
|
| 115 |
+
"name": "python3"
|
| 116 |
+
},
|
| 117 |
+
"accelerator": "GPU",
|
| 118 |
+
"language_info": {
|
| 119 |
+
"codemirror_mode": {
|
| 120 |
+
"name": "ipython",
|
| 121 |
+
"version": 2
|
| 122 |
+
},
|
| 123 |
+
"file_extension": ".py",
|
| 124 |
+
"mimetype": "text/x-python",
|
| 125 |
+
"name": "python",
|
| 126 |
+
"nbconvert_exporter": "python",
|
| 127 |
+
"pygments_lexer": "ipython2",
|
| 128 |
+
"version": "2.7.6"
|
| 129 |
+
}
|
| 130 |
+
},
|
| 131 |
+
"nbformat": 4,
|
| 132 |
+
"nbformat_minor": 5
|
| 133 |
+
}
|
colab/mai_mod_inference.ipynb
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"metadata": {},
|
| 5 |
+
"cell_type": "markdown",
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# Beatmap Modding with MaiMod\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook is an interactive demo of an AI-driven osu! Beatmap Modding Tool created by OliBomby. This model is capable of finding various faults and inconsistencies in beatmaps which other automated modding tools can not detect. Run this tool on your beatmaps to get suggestions on how to improve them.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"### Instructions for running:\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
|
| 16 |
+
"* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
|
| 17 |
+
"* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
|
| 18 |
+
"* __Upload Audio__: choose the beatmap song .mp3 or .ogg file from your computer. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
|
| 19 |
+
"* __Upload Beatmap__: choose the beatmap .osu file from your computer.\n",
|
| 20 |
+
"* __Generate Suggestions__ to generate suggestions for your uploaded beatmap.\n"
|
| 21 |
+
],
|
| 22 |
+
"id": "3c19902455e25588"
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"id": "initial_id",
|
| 27 |
+
"metadata": {
|
| 28 |
+
"collapsed": true
|
| 29 |
+
},
|
| 30 |
+
"source": [
|
| 31 |
+
"#@title Setup Environment { display-mode: \"form\" }\n",
|
| 32 |
+
"#@markdown Run this cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"!git clone https://github.com/OliBomby/Mapperatorinator.git\n",
|
| 35 |
+
"%cd Mapperatorinator\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"!pip install transformers==4.53.3\n",
|
| 38 |
+
"!pip install hydra-core\n",
|
| 39 |
+
"!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"import os\n",
|
| 42 |
+
"from google.colab import files\n",
|
| 43 |
+
"from mai_mod import main\n",
|
| 44 |
+
"from hydra import compose, initialize_config_dir\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"input_audio = \"\"\n",
|
| 47 |
+
"input_beatmap = \"\""
|
| 48 |
+
],
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"execution_count": null
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"metadata": {},
|
| 54 |
+
"cell_type": "code",
|
| 55 |
+
"source": [
|
| 56 |
+
"#@title Upload Audio { display-mode: \"form\" }\n",
|
| 57 |
+
"#@markdown Run this cell to upload the song of the beatmap that you want to mod. Please upload a .mp3 or .ogg file. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"def upload_audio():\n",
|
| 60 |
+
" data = list(files.upload().keys())\n",
|
| 61 |
+
" if len(data) > 1:\n",
|
| 62 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 63 |
+
" file = data[0]\n",
|
| 64 |
+
" if not file.endswith('.mp3') and not file.endswith('.ogg'):\n",
|
| 65 |
+
" print('Invalid file format. Please upload a .mp3 or .ogg file.')\n",
|
| 66 |
+
" return \"\"\n",
|
| 67 |
+
" return data[0]\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"input_audio = upload_audio()"
|
| 70 |
+
],
|
| 71 |
+
"id": "624a60c5777279e7",
|
| 72 |
+
"outputs": [],
|
| 73 |
+
"execution_count": null
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"cell_type": "code",
|
| 78 |
+
"source": [
|
| 79 |
+
"#@title Upload Beatmap { display-mode: \"form\" }\n",
|
| 80 |
+
"#@markdown Run this cell to upload the beatmap **.osu** file of the beatmap that you want to mod. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"def upload_beatmap():\n",
|
| 83 |
+
" data = list(files.upload().keys())\n",
|
| 84 |
+
" if len(data) > 1:\n",
|
| 85 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 86 |
+
" file = data[0]\n",
|
| 87 |
+
" if not file.endswith('.osu'):\n",
|
| 88 |
+
" print('Invalid file format. Please upload a .osu file.\\nIn stable you can find the .osu file in the song folder (File > Open Song Folder).\\nIn lazer you can find the .osu file by using File > Edit Externally.')\n",
|
| 89 |
+
" return \"\"\n",
|
| 90 |
+
" return file\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"input_beatmap = upload_beatmap()"
|
| 93 |
+
],
|
| 94 |
+
"id": "63884394491f6664",
|
| 95 |
+
"outputs": [],
|
| 96 |
+
"execution_count": null
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"metadata": {},
|
| 100 |
+
"cell_type": "code",
|
| 101 |
+
"source": [
|
| 102 |
+
"#@title Generate Suggestions { display-mode: \"form\" }\n",
|
| 103 |
+
"#@markdown Run this cell to generate suggestions for your uploaded beatmap. The suggestions will be printed in the output.\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"# Validate stuff\n",
|
| 106 |
+
"assert os.path.exists(input_beatmap), \"Please upload a beatmap.\"\n",
|
| 107 |
+
"assert os.path.exists(input_audio), \"Please upload an audio file.\"\n",
|
| 108 |
+
" \n",
|
| 109 |
+
"# Create config\n",
|
| 110 |
+
"config = \"mai_mod\"\n",
|
| 111 |
+
"with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/configs\"):\n",
|
| 112 |
+
" conf = compose(config_name=config)\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"# Do inference\n",
|
| 115 |
+
"conf.audio_path = input_audio\n",
|
| 116 |
+
"conf.beatmap_path = input_beatmap\n",
|
| 117 |
+
"conf.precision = \"fp32\" # For some reason AMP causes OOM in Colab\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"main(conf)"
|
| 120 |
+
],
|
| 121 |
+
"id": "166eb3e5f9398554",
|
| 122 |
+
"outputs": [],
|
| 123 |
+
"execution_count": null
|
| 124 |
+
}
|
| 125 |
+
],
|
| 126 |
+
"metadata": {
|
| 127 |
+
"kernelspec": {
|
| 128 |
+
"display_name": "Python 3",
|
| 129 |
+
"language": "python",
|
| 130 |
+
"name": "python3"
|
| 131 |
+
},
|
| 132 |
+
"accelerator": "GPU",
|
| 133 |
+
"language_info": {
|
| 134 |
+
"codemirror_mode": {
|
| 135 |
+
"name": "ipython",
|
| 136 |
+
"version": 2
|
| 137 |
+
},
|
| 138 |
+
"file_extension": ".py",
|
| 139 |
+
"mimetype": "text/x-python",
|
| 140 |
+
"name": "python",
|
| 141 |
+
"nbconvert_exporter": "python",
|
| 142 |
+
"pygments_lexer": "ipython2",
|
| 143 |
+
"version": "2.7.6"
|
| 144 |
+
}
|
| 145 |
+
},
|
| 146 |
+
"nbformat": 4,
|
| 147 |
+
"nbformat_minor": 5
|
| 148 |
+
}
|
colab/mapperatorinator_inference.ipynb
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"metadata": {},
|
| 5 |
+
"cell_type": "markdown",
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/mapperatorinator_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# Beatmap Generation with Mapperatorinator\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook is an interactive demo of an osu! beatmap generation model created by OliBomby. This model is capable of generating hit objects, hitsounds, timing, kiai times, and SVs for all 4 gamemodes. You can upload a beatmap to give to the model as additional context or remap parts of the beatmap.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"### Instructions for running:\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"* Read and accept the rules regarding using this tool by clicking the checkbox.\n",
|
| 16 |
+
"* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
|
| 17 |
+
"* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
|
| 18 |
+
"* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
|
| 19 |
+
"* __Upload Audio__: choose a .mp3 or .ogg file from your computer.\n",
|
| 20 |
+
"* __Upload Beatmap__: optionally choose a beatmap .osu file from your computer. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
|
| 21 |
+
"* __Configure__: choose your generation parameters to control the style of the generated beatmap.\n",
|
| 22 |
+
"* Generate the beatmap using the __Generate Beatmap__ cell. (it may take a few minutes depending on the length of the song)\n"
|
| 23 |
+
],
|
| 24 |
+
"id": "3c19902455e25588"
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"id": "initial_id",
|
| 29 |
+
"metadata": {
|
| 30 |
+
"collapsed": true
|
| 31 |
+
},
|
| 32 |
+
"source": [
|
| 33 |
+
"#@title Setup Environment { display-mode: \"form\" }\n",
|
| 34 |
+
"#@markdown ### Use this tool responsibly. Always disclose the use of AI in your beatmaps. Accept the rules and run this cell.\n",
|
| 35 |
+
"i_accept_the_rules = False # @param {type:\"boolean\"}\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"assert i_accept_the_rules, \"Read and accept the rules first!\"\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"!git clone https://github.com/hongminh54/BeatHeritage.git\n",
|
| 40 |
+
"%cd Mapperatorinator\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"!pip install transformers==4.53.3\n",
|
| 43 |
+
"!pip install hydra-core nnaudio\n",
|
| 44 |
+
"!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"from google.colab import files\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"import os\n",
|
| 49 |
+
"from hydra import compose, initialize_config_dir\n",
|
| 50 |
+
"from osuT5.osuT5.event import ContextType\n",
|
| 51 |
+
"from inference import main\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"output_path = \"output\"\n",
|
| 54 |
+
"input_audio = \"\"\n",
|
| 55 |
+
"input_beatmap = \"\""
|
| 56 |
+
],
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"execution_count": null
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"cell_type": "code",
|
| 63 |
+
"source": [
|
| 64 |
+
"#@title Upload Audio { display-mode: \"form\" }\n",
|
| 65 |
+
"#@markdown Run this cell to upload audio. This is the song to generate a beatmap for. Please upload a .mp3 or .ogg file.\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"def upload_audio():\n",
|
| 68 |
+
" data = list(files.upload().keys())\n",
|
| 69 |
+
" if len(data) > 1:\n",
|
| 70 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 71 |
+
" file = data[0]\n",
|
| 72 |
+
" if not file.endswith('.mp3') and not file.endswith('.ogg'):\n",
|
| 73 |
+
" print('Invalid file format. Please upload a .mp3 or .ogg file.')\n",
|
| 74 |
+
" return \"\"\n",
|
| 75 |
+
" return data[0]\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"input_audio = upload_audio()"
|
| 78 |
+
],
|
| 79 |
+
"id": "624a60c5777279e7",
|
| 80 |
+
"outputs": [],
|
| 81 |
+
"execution_count": null
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"source": [
|
| 87 |
+
"#@title (Optional) Upload Beatmap { display-mode: \"form\" }\n",
|
| 88 |
+
"#@markdown This step is required if you want to use `in_context` or `add_to_beatmap` to provide additional info to the model.\n",
|
| 89 |
+
"#@markdown It will also fill in any missing metadata and unknown values in the configuration using info of the reference beatmap.\n",
|
| 90 |
+
"#@markdown Please upload a **.osu** file. You can find the .osu file in the song folder in stable or by using File > Edit Externally in lazer.\n",
|
| 91 |
+
"use_reference_beatmap = False # @param {type:\"boolean\"}\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"def upload_beatmap():\n",
|
| 94 |
+
" data = list(files.upload().keys())\n",
|
| 95 |
+
" if len(data) > 1:\n",
|
| 96 |
+
" print('Multiple files uploaded; using only one.')\n",
|
| 97 |
+
" file = data[0]\n",
|
| 98 |
+
" if not file.endswith('.osu'):\n",
|
| 99 |
+
" print('Invalid file format. Please upload a .osu file.\\nIn stable you can find the .osu file in the song folder (File > Open Song Folder).\\nIn lazer you can find the .osu file by using File > Edit Externally.')\n",
|
| 100 |
+
" return \"\"\n",
|
| 101 |
+
" return file\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"if use_reference_beatmap:\n",
|
| 104 |
+
" input_beatmap = upload_beatmap()\n",
|
| 105 |
+
"else:\n",
|
| 106 |
+
" input_beatmap = \"\""
|
| 107 |
+
],
|
| 108 |
+
"id": "63884394491f6664",
|
| 109 |
+
"outputs": [],
|
| 110 |
+
"execution_count": null
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"cell_type": "code",
|
| 115 |
+
"source": [
|
| 116 |
+
"#@title Configure and Generate Beatmap { display-mode: \"form\" }\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"#@markdown #### You can input -1 to leave the value unknown.\n",
|
| 119 |
+
"#@markdown ---\n",
|
| 120 |
+
"#@markdown This is the AI model to use. V30 is the most accurate model, but it does not support other gamemodes, year, descriptors, or in_context.\n",
|
| 121 |
+
"model = \"Mapperatorinator V30\" # @param [\"Mapperatorinator V29\", \"Mapperatorinator V30\"]\n",
|
| 122 |
+
"#@markdown This is the game mode to generate a beatmap for.\n",
|
| 123 |
+
"gamemode = \"standard\" # @param [\"standard\", \"taiko\", \"catch the beat\", \"mania\"]\n",
|
| 124 |
+
"#@markdown This is the Star Rating you want your beatmap to be. It might deviate from this number depending on the song intensity and other configuration.\n",
|
| 125 |
+
"difficulty = 5 # @param {type:\"number\"}\n",
|
| 126 |
+
"#@markdown This is the user ID of the ranked mapper to imitate for mapping style. You can find this in the URL of the mapper's profile.\n",
|
| 127 |
+
"mapper_id = -1 # @param {type:\"integer\"}\n",
|
| 128 |
+
"#@markdown This is the year you want the beatmap to be from. It should be in the range of 2007 to 2023.\n",
|
| 129 |
+
"year = 2023 # @param {type:\"integer\"}\n",
|
| 130 |
+
"#@markdown This is whether you want the beatmap to be hitsounded. This works only for mania mode.\n",
|
| 131 |
+
"hitsounded = True # @param {type:\"boolean\"}\n",
|
| 132 |
+
"#@markdown These are the standard difficulty parameters for the beatmap HP, OD, AR, and CS. These are the same as the ones in the editor.\n",
|
| 133 |
+
"hp_drain_rate = 5 # @param {type:\"number\"}\n",
|
| 134 |
+
"circle_size = 4 # @param {type:\"number\"}\n",
|
| 135 |
+
"overall_difficulty = 9 # @param {type:\"number\"}\n",
|
| 136 |
+
"approach_rate = 8 # @param {type:\"number\"}\n",
|
| 137 |
+
"slider_multiplier = 1.4 # @param {type:\"slider\", min:0.4, max:3.6, step:0.1}\n",
|
| 138 |
+
"slider_tick_rate = 1 # @param {type:\"number\"}\n",
|
| 139 |
+
"#@markdown This is the number of keys for the mania beatmap. This works only for mania mode.\n",
|
| 140 |
+
"keycount = 4 # @param {type:\"slider\", min:1, max:18, step:1}\n",
|
| 141 |
+
"#@markdown This is the ratio of hold notes to circles in the beatmap. It should be in the range [0,1]. This works only for mania mode.\n",
|
| 142 |
+
"hold_note_ratio = -1 # @param {type:\"number\"}\n",
|
| 143 |
+
"#@markdown This is the ratio of scroll speed changes to the number of notes. It should be in the range [0,1]. This works only for mania and taiko modes.\n",
|
| 144 |
+
"scroll_speed_ratio = -1 # @param {type:\"number\"}\n",
|
| 145 |
+
"#@markdown These descriptors of the beatmap. Descriptors are used to describe the style of the beatmap. All available descriptors can be found [here](https://osu.ppy.sh/wiki/en/Beatmap/Beatmap_tags).\n",
|
| 146 |
+
"descriptor_1 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 147 |
+
"descriptor_2 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 148 |
+
"descriptor_3 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 149 |
+
"#@markdown These are negative descriptors of the beatmap. Negative descriptors are used to describe what the beatmap should not have. These work only when `cfg_scale` is greater than 1.\n",
|
| 150 |
+
"negative_descriptor_1 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 151 |
+
"negative_descriptor_2 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 152 |
+
"negative_descriptor_3 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
|
| 153 |
+
"#@markdown ---\n",
|
| 154 |
+
"#@markdown If true, the generated beatmap will be exported as a .osz file. Otherwise, it will be exported as a .osu file.\n",
|
| 155 |
+
"export_osz = False # @param {type:\"boolean\"}\n",
|
| 156 |
+
"#@markdown If true, the generated beatmap will be added to the reference beatmap and the reference beatmap will be modified instead of creating a new beatmap. It will also continue any hit objects before the start time in the reference beatmap.\n",
|
| 157 |
+
"add_to_beatmap = False # @param {type:\"boolean\"}\n",
|
| 158 |
+
"#@markdown This is the start time of the beatmap in milliseconds. Use this to constrain the generation to a specific part of the song.\n",
|
| 159 |
+
"start_time = -1 # @param {type:\"integer\"}\n",
|
| 160 |
+
"#@markdown This is the end time of the beatmap in milliseconds. Use this to constrain the generation to a specific part of the song.\n",
|
| 161 |
+
"end_time = -1 # @param {type:\"integer\"}\n",
|
| 162 |
+
"#@markdown This is which additional information to give to the model:\n",
|
| 163 |
+
"#@markdown - TIMING: Give timing points to the model. This will skip the timing point generation step.\n",
|
| 164 |
+
"#@markdown - KIAI: Give kiai times to the model. This will skip the kiai time generation step.\n",
|
| 165 |
+
"#@markdown - MAP: Give hit objects to the model. This will skip the hit object generation step.\n",
|
| 166 |
+
"#@markdown - GD: Give hit objects of another difficulty in the same mapset to the model (can be a different game mode). It will improve the rhythm accuracy and consistency of the generated beatmap without copying the reference beatmap.\n",
|
| 167 |
+
"#@markdown - NO_HS: Give hit objects without hitsounds to the model. This will copy the hit objects of the reference beatmap and only add hitsounds to them.\n",
|
| 168 |
+
"in_context = \"[NONE]\" # @param [\"[NONE]\", \"[TIMING]\", \"[TIMING,KIAI]\", \"[TIMING,KIAI,MAP]\", \"[GD,TIMING,KIAI]\", \"[NO_HS,TIMING,KIAI]\"]\n",
|
| 169 |
+
"#@markdown This is the output type of the beatmap. You can choose to either generate everything or only generate timing points.\n",
|
| 170 |
+
"output_type = \"[MAP]\" # @param [\"[MAP]\", \"[TIMING,KIAI,MAP,SV]\", \"[TIMING]\"]\n",
|
| 171 |
+
"#@markdown This is the scale of the classifier-free guidance. A higher scale will make the model more likely to follow the descriptors and mapper style. A high `cfg_scale` or certain combinations of settings can produce unexpected results, so use it with caution. \n",
|
| 172 |
+
"cfg_scale = 1 # @param {type:\"slider\", min:1, max:5, step:0.1}\n",
|
| 173 |
+
"#@markdown This is the temperature of the sampling. A lower temperature will make the model more conservative and less creative. I only recommend lowering this slightly or when using `add_to_beatmap` and generating small sections.\n",
|
| 174 |
+
"temperature = 1 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
|
| 175 |
+
"#@markdown This is the random seed. Change this to sample a different beatmap with the same settings.\n",
|
| 176 |
+
"seed = -1 # @param {type:\"integer\"}\n",
|
| 177 |
+
"#@markdown ---\n",
|
| 178 |
+
"#@markdown If true, uses a slow and accurate timing generator. This will make the generation slower, but the timing will be more accurate.\n",
|
| 179 |
+
"#@markdown This is the leniency of the normal timing generator. It will allow the timing ticks to deviate from the real timing by this many milliseconds. A higher value will result in less timing points.\n",
|
| 180 |
+
"timing_leniency = 20 # @param {type:\"slider\", min:0, max:100, step:1}\n",
|
| 181 |
+
"super_timing = False # @param {type:\"boolean\"}\n",
|
| 182 |
+
"#@markdown This is the number of beams for beam search for the super timing generator. Higher values will result in slightly more accurate timing at the cost of speed. \n",
|
| 183 |
+
"timer_num_beams = 2 # @param {type:\"slider\", min:1, max:16, step:1}\n",
|
| 184 |
+
"#@markdown This is the number of iterations for the super timing generator. Higher values will result in slightly more accurate timing at the cost of speed.\n",
|
| 185 |
+
"timer_iterations = 20 # @param {type:\"slider\", min:1, max:100, step:1}\n",
|
| 186 |
+
"#@markdown This is the certainty threshold requirement for BPM changes in the super timing generator. Higher values will result in less BPM changes.\n",
|
| 187 |
+
"timer_bpm_threshold = 0.1 # @param {type:\"slider\", min:0, max:1, step:0.1}\n",
|
| 188 |
+
"#@markdown ---\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"# Get actual parameters\n",
|
| 191 |
+
"a_config = model.split(' ')[-1].lower()\n",
|
| 192 |
+
"a_gamemode = [\"standard\", \"taiko\", \"catch the beat\", \"mania\"].index(gamemode)\n",
|
| 193 |
+
"a_difficulty = None if difficulty == -1 else difficulty\n",
|
| 194 |
+
"a_mapper_id = None if mapper_id == -1 else mapper_id\n",
|
| 195 |
+
"a_year = None if year == -1 else year\n",
|
| 196 |
+
"a_hp_drain_rate = None if hp_drain_rate == -1 else hp_drain_rate\n",
|
| 197 |
+
"a_circle_size = None if circle_size == -1 else circle_size\n",
|
| 198 |
+
"a_overall_difficulty = None if overall_difficulty == -1 else overall_difficulty\n",
|
| 199 |
+
"a_approach_rate = None if approach_rate == -1 else approach_rate\n",
|
| 200 |
+
"a_slider_multiplier = None if slider_multiplier == -1 else slider_multiplier\n",
|
| 201 |
+
"a_slider_tick_rate = None if slider_tick_rate == -1 else slider_tick_rate\n",
|
| 202 |
+
"a_hold_note_ratio = None if hold_note_ratio == -1 else hold_note_ratio\n",
|
| 203 |
+
"a_scroll_speed_ratio = None if scroll_speed_ratio == -1 else scroll_speed_ratio\n",
|
| 204 |
+
"descriptors = [d for d in [descriptor_1, descriptor_2, descriptor_3] if d != '']\n",
|
| 205 |
+
"negative_descriptors = [d for d in [negative_descriptor_1, negative_descriptor_2, negative_descriptor_3] if d != '']\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"a_start_time = None if start_time == -1 else start_time\n",
|
| 208 |
+
"a_end_time = None if end_time == -1 else end_time\n",
|
| 209 |
+
"a_in_context = [ContextType(c.lower()) for c in in_context[1:-1].split(',')]\n",
|
| 210 |
+
"a_output_type = [ContextType(c.lower()) for c in output_type[1:-1].split(',')]\n",
|
| 211 |
+
"a_seed = None if seed == -1 else seed\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"# Validate stuff\n",
|
| 214 |
+
"if any(c in a_in_context for c in [ContextType.TIMING, ContextType.KIAI, ContextType.MAP, ContextType.SV, ContextType.GD, ContextType.NO_HS]) or add_to_beatmap:\n",
|
| 215 |
+
" assert os.path.exists(input_beatmap), \"Please upload a reference beatmap.\"\n",
|
| 216 |
+
"assert os.path.exists(input_audio), \"Please upload an audio file.\"\n",
|
| 217 |
+
"if a_config == \"v30\":\n",
|
| 218 |
+
" assert a_gamemode == 0, \"V30 only supports standard mode.\"\n",
|
| 219 |
+
" if any(c in a_in_context for c in [ContextType.KIAI, ContextType.MAP, ContextType.SV]):\n",
|
| 220 |
+
" print(\"WARNING: V30 does not support KIAI, MAP, or SV in_context, ignoring.\")\n",
|
| 221 |
+
" if output_type != \"[MAP]\":\n",
|
| 222 |
+
" print(\"WARNING: V30 only supports [MAP] output type, setting output type to [MAP].\")\n",
|
| 223 |
+
" a_output_type = [ContextType.MAP]\n",
|
| 224 |
+
" if len(descriptors) != 0 and len(negative_descriptors) != 0:\n",
|
| 225 |
+
" print(\"WARNING: V30 does not support descriptors or negative descriptors, ignoring.\")\n",
|
| 226 |
+
" if super_timing:\n",
|
| 227 |
+
" print(\"WARNING: V30 does not fully support super timing, generation will be VERY slow.\")\n",
|
| 228 |
+
" \n",
|
| 229 |
+
"# Create config\n",
|
| 230 |
+
"with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/configs/inference\"):\n",
|
| 231 |
+
" conf = compose(config_name=a_config)\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"# Do inference\n",
|
| 234 |
+
"conf.audio_path = input_audio\n",
|
| 235 |
+
"conf.output_path = output_path\n",
|
| 236 |
+
"conf.beatmap_path = input_beatmap\n",
|
| 237 |
+
"conf.gamemode = a_gamemode\n",
|
| 238 |
+
"conf.difficulty = a_difficulty\n",
|
| 239 |
+
"conf.mapper_id = a_mapper_id\n",
|
| 240 |
+
"conf.year = a_year\n",
|
| 241 |
+
"conf.hitsounded = hitsounded\n",
|
| 242 |
+
"conf.hp_drain_rate = a_hp_drain_rate\n",
|
| 243 |
+
"conf.circle_size = a_circle_size\n",
|
| 244 |
+
"conf.overall_difficulty = a_overall_difficulty\n",
|
| 245 |
+
"conf.approach_rate = a_approach_rate\n",
|
| 246 |
+
"conf.slider_multiplier = a_slider_multiplier\n",
|
| 247 |
+
"conf.slider_tick_rate = a_slider_tick_rate\n",
|
| 248 |
+
"conf.keycount = keycount\n",
|
| 249 |
+
"conf.hold_note_ratio = a_hold_note_ratio\n",
|
| 250 |
+
"conf.scroll_speed_ratio = a_scroll_speed_ratio\n",
|
| 251 |
+
"conf.descriptors = descriptors\n",
|
| 252 |
+
"conf.negative_descriptors = negative_descriptors\n",
|
| 253 |
+
"conf.export_osz = export_osz\n",
|
| 254 |
+
"conf.add_to_beatmap = add_to_beatmap\n",
|
| 255 |
+
"conf.start_time = a_start_time\n",
|
| 256 |
+
"conf.end_time = a_end_time\n",
|
| 257 |
+
"conf.in_context = a_in_context\n",
|
| 258 |
+
"conf.output_type = a_output_type\n",
|
| 259 |
+
"conf.cfg_scale = cfg_scale\n",
|
| 260 |
+
"conf.temperature = temperature\n",
|
| 261 |
+
"conf.seed = a_seed\n",
|
| 262 |
+
"conf.timing_leniency = timing_leniency\n",
|
| 263 |
+
"conf.super_timing = super_timing\n",
|
| 264 |
+
"conf.timer_num_beams = timer_num_beams\n",
|
| 265 |
+
"conf.timer_iterations = timer_iterations\n",
|
| 266 |
+
"conf.timer_bpm_threshold = timer_bpm_threshold\n",
|
| 267 |
+
"\n",
|
| 268 |
+
"_, result_path, osz_path = main(conf)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"if osz_path is not None:\n",
|
| 271 |
+
" result_path = osz_path\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"if conf.add_to_beatmap:\n",
|
| 274 |
+
" files.download(result_path)\n",
|
| 275 |
+
"else:\n",
|
| 276 |
+
" files.download(result_path)\n"
|
| 277 |
+
],
|
| 278 |
+
"id": "166eb3e5f9398554",
|
| 279 |
+
"outputs": [],
|
| 280 |
+
"execution_count": null
|
| 281 |
+
}
|
| 282 |
+
],
|
| 283 |
+
"metadata": {
|
| 284 |
+
"kernelspec": {
|
| 285 |
+
"display_name": "Python 3",
|
| 286 |
+
"language": "python",
|
| 287 |
+
"name": "python3"
|
| 288 |
+
},
|
| 289 |
+
"accelerator": "GPU",
|
| 290 |
+
"language_info": {
|
| 291 |
+
"codemirror_mode": {
|
| 292 |
+
"name": "ipython",
|
| 293 |
+
"version": 2
|
| 294 |
+
},
|
| 295 |
+
"file_extension": ".py",
|
| 296 |
+
"mimetype": "text/x-python",
|
| 297 |
+
"name": "python",
|
| 298 |
+
"nbconvert_exporter": "python",
|
| 299 |
+
"pygments_lexer": "ipython2",
|
| 300 |
+
"version": "2.7.6"
|
| 301 |
+
}
|
| 302 |
+
},
|
| 303 |
+
"nbformat": 4,
|
| 304 |
+
"nbformat_minor": 5
|
| 305 |
+
}
|
collate_results.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_color_for_value(value, min_val, max_val, lower_is_better=False):
|
| 6 |
+
"""
|
| 7 |
+
Generates an HSL color string from red to green based on a value's
|
| 8 |
+
position between a min and max.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
value (float): The current value.
|
| 12 |
+
min_val (float): The minimum value in the dataset for this metric.
|
| 13 |
+
max_val (float): The maximum value in the dataset for this metric.
|
| 14 |
+
lower_is_better (bool): If True, lower values get greener colors.
|
| 15 |
+
|
| 16 |
+
Returns:
|
| 17 |
+
str: An HSL color string for use in CSS.
|
| 18 |
+
"""
|
| 19 |
+
# Avoid division by zero if all values are the same
|
| 20 |
+
if min_val == max_val:
|
| 21 |
+
return "hsl(120, 70%, 60%)" # Default to green
|
| 22 |
+
|
| 23 |
+
# Normalize the value to a 0-1 range
|
| 24 |
+
normalized = (value - min_val) / (max_val - min_val)
|
| 25 |
+
|
| 26 |
+
if lower_is_better:
|
| 27 |
+
# Invert the scale: 1 (best) -> 0 (worst)
|
| 28 |
+
hue = (1 - normalized) * 120
|
| 29 |
+
else:
|
| 30 |
+
# Standard scale: 0 (worst) -> 1 (best)
|
| 31 |
+
hue = normalized * 120
|
| 32 |
+
|
| 33 |
+
# Return HSL color: hue from 0 (red) to 120 (green), with fixed saturation and lightness
|
| 34 |
+
return f"hsl({hue:.0f}, 70%, 60%)"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def parse_log_files(root_dir):
|
| 38 |
+
"""
|
| 39 |
+
Parses log files in subdirectories to extract metrics and format them
|
| 40 |
+
into an HTML table with colored cells.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
root_dir (str): The path to the main folder containing the model subfolders.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
str: A string containing the formatted HTML table.
|
| 47 |
+
"""
|
| 48 |
+
results = []
|
| 49 |
+
dir_pattern = re.compile(r"inference=(inference_)?([a-zA-Z0-9_-]+)")
|
| 50 |
+
metric_patterns = {
|
| 51 |
+
'FID': re.compile(r"FID: ([\d.]+)"),
|
| 52 |
+
'AR Pr.': re.compile(r"Active Rhythm Precision: ([\d.]+)"),
|
| 53 |
+
'AR Re.': re.compile(r"Active Rhythm Recall: ([\d.]+)"),
|
| 54 |
+
'AR F1': re.compile(r"Active Rhythm F1: ([\d.]+)"),
|
| 55 |
+
'PR Pr.': re.compile(r"Passive Rhythm Precision: ([\d.]+)"),
|
| 56 |
+
'PR Re.': re.compile(r"Passive Rhythm Recall: ([\d.]+)"),
|
| 57 |
+
'PR F1': re.compile(r"Passive Rhythm F1: ([\d.]+)")
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
for dirpath, dirnames, filenames in os.walk(root_dir):
|
| 61 |
+
if dirpath == root_dir:
|
| 62 |
+
for dirname in dirnames:
|
| 63 |
+
dir_match = dir_pattern.match(dirname)
|
| 64 |
+
if not dir_match:
|
| 65 |
+
continue
|
| 66 |
+
model_name = dir_match.group(2)
|
| 67 |
+
log_file_path = os.path.join(dirpath, dirname, 'calc_fid.log')
|
| 68 |
+
|
| 69 |
+
if not os.path.exists(log_file_path):
|
| 70 |
+
print(f"Warning: 'calc_fid.log' not found in {dirname}")
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
latest_metrics = {}
|
| 74 |
+
try:
|
| 75 |
+
with open(log_file_path, 'r') as f:
|
| 76 |
+
for line in f:
|
| 77 |
+
for key, pattern in metric_patterns.items():
|
| 78 |
+
match = pattern.search(line)
|
| 79 |
+
if match:
|
| 80 |
+
latest_metrics[key] = float(match.group(1))
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Error reading {log_file_path}: {e}")
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
if latest_metrics:
|
| 86 |
+
latest_metrics['Model name'] = model_name
|
| 87 |
+
results.append(latest_metrics)
|
| 88 |
+
dirnames[:] = []
|
| 89 |
+
|
| 90 |
+
if not results:
|
| 91 |
+
return "<p>No results found. Check if <code>root_dir</code> is correct and log files exist.</p>"
|
| 92 |
+
|
| 93 |
+
# --- Pre-calculate Min/Max for coloring ---
|
| 94 |
+
headers = ["Model name", "FID", "AR Pr.", "AR Re.", "AR F1", "PR Pr.", "PR Re.", "PR F1"]
|
| 95 |
+
min_max_vals = {}
|
| 96 |
+
for header in headers:
|
| 97 |
+
if header == "Model name":
|
| 98 |
+
continue
|
| 99 |
+
# Get all valid values for the current header
|
| 100 |
+
values = [res.get(header) for res in results if res.get(header) is not None]
|
| 101 |
+
if values:
|
| 102 |
+
min_max_vals[header] = {'min': min(values), 'max': max(values)}
|
| 103 |
+
|
| 104 |
+
# --- Generate HTML Table ---
|
| 105 |
+
html = ["<table>"]
|
| 106 |
+
# Header row
|
| 107 |
+
html.append(" <thead>")
|
| 108 |
+
html.append(" <tr>" + "".join([f"<th>{h}</th>" for h in headers]) + "</tr>")
|
| 109 |
+
html.append(" </thead>")
|
| 110 |
+
|
| 111 |
+
# Data rows
|
| 112 |
+
html.append(" <tbody>")
|
| 113 |
+
for res in sorted(results, key=lambda x: x.get('Model name', '')):
|
| 114 |
+
row_html = " <tr>"
|
| 115 |
+
for header in headers:
|
| 116 |
+
value = res.get(header)
|
| 117 |
+
|
| 118 |
+
if header == 'Model name':
|
| 119 |
+
row_html += f"<td>{res.get('Model name', 'N/A')}</td>"
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
if value is None:
|
| 123 |
+
row_html += "<td>N/A</td>"
|
| 124 |
+
continue
|
| 125 |
+
|
| 126 |
+
# Formatting
|
| 127 |
+
if header == 'FID':
|
| 128 |
+
formatted_value = f"{value:.2f}"
|
| 129 |
+
lower_is_better = True
|
| 130 |
+
else:
|
| 131 |
+
formatted_value = f"{value:.3f}"
|
| 132 |
+
lower_is_better = False
|
| 133 |
+
|
| 134 |
+
# Get color and apply style
|
| 135 |
+
color = get_color_for_value(value, min_max_vals[header]['min'], min_max_vals[header]['max'],
|
| 136 |
+
lower_is_better)
|
| 137 |
+
# Added a light text shadow for better readability on bright colors
|
| 138 |
+
style = f"background-color: {color}; color: black; text-shadow: 0 0 5px white;"
|
| 139 |
+
row_html += f'<td style="{style}">{formatted_value}</td>'
|
| 140 |
+
|
| 141 |
+
row_html += "</tr>"
|
| 142 |
+
html.append(row_html)
|
| 143 |
+
|
| 144 |
+
html.append(" </tbody>")
|
| 145 |
+
html.append("</table>")
|
| 146 |
+
|
| 147 |
+
return "\n".join(html)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
if __name__ == '__main__':
|
| 151 |
+
# --- IMPORTANT ---
|
| 152 |
+
# Change this to the path of your main results folder.
|
| 153 |
+
# You can use "." if the script is in the same parent folder as the "inference=..." folders.
|
| 154 |
+
logs_directory = './logs_fid/sweeps/test_3'
|
| 155 |
+
|
| 156 |
+
markdown_table = parse_log_files(logs_directory)
|
| 157 |
+
print(markdown_table)
|
| 158 |
+
|
compose.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mapperatorinator
|
| 2 |
+
services:
|
| 3 |
+
mapperatorinator:
|
| 4 |
+
stdin_open: true
|
| 5 |
+
tty: true
|
| 6 |
+
deploy:
|
| 7 |
+
resources:
|
| 8 |
+
reservations:
|
| 9 |
+
devices:
|
| 10 |
+
- driver: nvidia
|
| 11 |
+
count: all
|
| 12 |
+
capabilities:
|
| 13 |
+
- gpu
|
| 14 |
+
volumes:
|
| 15 |
+
- .:/workspace/Mapperatorinator
|
| 16 |
+
- ../datasets:/workspace/datasets
|
| 17 |
+
network_mode: host
|
| 18 |
+
container_name: mapperatorinator_space
|
| 19 |
+
shm_size: 8gb
|
| 20 |
+
build: .
|
| 21 |
+
# image: my_fixed_image
|
| 22 |
+
command: /bin/bash
|
| 23 |
+
environment:
|
| 24 |
+
- PROJECT_PATH=/workspace/Mapperatorinator
|
| 25 |
+
- WANDB_API_KEY=${WANDB_API_KEY}
|
config.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
|
| 4 |
+
from hydra.core.config_store import ConfigStore
|
| 5 |
+
from omegaconf import MISSING
|
| 6 |
+
|
| 7 |
+
from osuT5.osuT5.config import TrainConfig
|
| 8 |
+
from osuT5.osuT5.tokenizer import ContextType
|
| 9 |
+
from osu_diffusion.config import DiffusionTrainConfig
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# BeatHeritage V1 Config Sections
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class AdvancedFeaturesConfig:
|
| 16 |
+
enable_context_aware_generation: bool = True
|
| 17 |
+
enable_style_preservation: bool = True
|
| 18 |
+
enable_difficulty_scaling: bool = True
|
| 19 |
+
enable_pattern_variety: bool = True
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class QualityControlConfig:
|
| 23 |
+
min_distance_threshold: int = 20
|
| 24 |
+
max_overlap_ratio: float = 0.15
|
| 25 |
+
enable_auto_correction: bool = True
|
| 26 |
+
enable_flow_optimization: bool = True
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class PerformanceConfig:
|
| 30 |
+
use_flash_attention: bool = False
|
| 31 |
+
batch_size: int = 1
|
| 32 |
+
max_sequence_length: int = 5120
|
| 33 |
+
cache_size: int = 4096
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class MetadataConfig:
|
| 37 |
+
preserve_timing_points: bool = True
|
| 38 |
+
preserve_bookmarks: bool = True
|
| 39 |
+
auto_detect_kiai: bool = True
|
| 40 |
+
smart_hitsounding: bool = True
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class PostprocessorConfig:
|
| 44 |
+
use_custom: bool = True
|
| 45 |
+
class_name: str = 'beatheritage_postprocessor.BeatHeritagePostprocessor'
|
| 46 |
+
config_class: str = 'beatheritage_postprocessor.BeatHeritageConfig'
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class IntegrationsConfig:
|
| 50 |
+
mai_mod_enhanced: bool = True
|
| 51 |
+
fid_evaluation: bool = True
|
| 52 |
+
benchmark_mode: bool = False
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Default config here based on V28
|
| 56 |
+
|
| 57 |
+
@dataclass
|
| 58 |
+
class InferenceConfig:
|
| 59 |
+
model_path: str = '' # Path to trained model
|
| 60 |
+
audio_path: str = '' # Path to input audio
|
| 61 |
+
output_path: str = '' # Path to output directory
|
| 62 |
+
beatmap_path: str = '' # Path to .osu file to autofill metadata and use as reference
|
| 63 |
+
|
| 64 |
+
# Conditional generation settings
|
| 65 |
+
gamemode: Optional[int] = None # Gamemode of the beatmap
|
| 66 |
+
beatmap_id: Optional[int] = None # Beatmap ID to use as style
|
| 67 |
+
difficulty: Optional[float] = None # Difficulty star rating to map
|
| 68 |
+
mapper_id: Optional[int] = None # Mapper ID to use as style
|
| 69 |
+
year: Optional[int] = None # Year to use as style
|
| 70 |
+
hitsounded: Optional[bool] = None # Whether the beatmap has hitsounds
|
| 71 |
+
keycount: Optional[int] = None # Number of keys to use for mania
|
| 72 |
+
hold_note_ratio: Optional[float] = None # Ratio of how many hold notes to generate in mania
|
| 73 |
+
scroll_speed_ratio: Optional[float] = None # Ratio of how many scroll speed changes to generate in mania and taiko
|
| 74 |
+
descriptors: Optional[list[str]] = None # List of descriptors to use for style
|
| 75 |
+
negative_descriptors: Optional[list[str]] = None # List of descriptors to avoid when using classifier-free guidance
|
| 76 |
+
|
| 77 |
+
# Difficulty settings
|
| 78 |
+
hp_drain_rate: Optional[float] = None # HP drain rate (HP)
|
| 79 |
+
circle_size: Optional[float] = None # Circle size (CS)
|
| 80 |
+
overall_difficulty: Optional[float] = None # Overall difficulty (OD)
|
| 81 |
+
approach_rate: Optional[float] = None # Approach rate (AR)
|
| 82 |
+
slider_multiplier: Optional[float] = None # Multiplier for slider velocity
|
| 83 |
+
slider_tick_rate: Optional[float] = None # Rate of slider ticks
|
| 84 |
+
|
| 85 |
+
# Inference settings
|
| 86 |
+
seed: Optional[int] = None # Random seed
|
| 87 |
+
device: str = 'auto' # Inference device (cpu/cuda/mps/auto)
|
| 88 |
+
precision: str = 'fp32' # Lower precision for speed (fp32/bf16/amp)
|
| 89 |
+
add_to_beatmap: bool = False # Add generated content to the reference beatmap
|
| 90 |
+
export_osz: bool = False # Export beatmap as .osz file
|
| 91 |
+
start_time: Optional[int] = None # Start time of audio to generate beatmap for
|
| 92 |
+
end_time: Optional[int] = None # End time of audio to generate beatmap for
|
| 93 |
+
lookback: float = 0.5 # Fraction of audio sequence to fill with tokens from previous inference window
|
| 94 |
+
lookahead: float = 0.4 # Fraction of audio sequence to skip at the end of the audio window
|
| 95 |
+
timing_leniency: int = 20 # Number of milliseconds of error to allow for timing generation
|
| 96 |
+
in_context: list[ContextType] = field(default_factory=lambda: [ContextType.NONE]) # Context types of other beatmap(s)
|
| 97 |
+
output_type: list[ContextType] = field(default_factory=lambda: [ContextType.MAP]) # Output type (map, timing)
|
| 98 |
+
cfg_scale: float = 1.0 # Scale of classifier-free guidance
|
| 99 |
+
temperature: float = 1.0 # Sampling temperature
|
| 100 |
+
timing_temperature: float = 0.1 # Sampling temperature for timing
|
| 101 |
+
mania_column_temperature: float = 0.5 # Sampling temperature for mania columns
|
| 102 |
+
taiko_hit_temperature: float = 0.5 # Sampling temperature for taiko hit types
|
| 103 |
+
timeshift_bias: float = 0.0 # Logit bias for sampling timeshift tokens
|
| 104 |
+
top_p: float = 0.95 # Top-p sampling threshold
|
| 105 |
+
top_k: int = 0 # Top-k sampling threshold
|
| 106 |
+
repetition_penalty: float = 1.0 # Repetition penalty to reduce repetitive patterns
|
| 107 |
+
parallel: bool = False # Use parallel sampling
|
| 108 |
+
do_sample: bool = True # Use sampling
|
| 109 |
+
num_beams: int = 1 # Number of beams for beam search
|
| 110 |
+
super_timing: bool = False # Use super timing generator (slow but accurate timing)
|
| 111 |
+
timer_num_beams: int = 2 # Number of beams for beam search
|
| 112 |
+
timer_bpm_threshold: float = 0.7 # Threshold requirement for BPM change in timer, higher values will result in less BPM changes
|
| 113 |
+
timer_cfg_scale: float = 1.0 # Scale of classifier-free guidance for timer
|
| 114 |
+
timer_iterations: int = 20 # Number of iterations for timer
|
| 115 |
+
use_server: bool = True # Use server for optimized multiprocess inference
|
| 116 |
+
max_batch_size: int = 16 # Maximum batch size for inference (only used for parallel sampling or super timing)
|
| 117 |
+
resnap_events: bool = True # Resnap notes to the timing after generation
|
| 118 |
+
position_refinement: bool = False # Use position refinement
|
| 119 |
+
|
| 120 |
+
# Metadata settings
|
| 121 |
+
bpm: int = 120 # Beats per minute of input audio
|
| 122 |
+
offset: int = 0 # Start of beat, in miliseconds, from the beginning of input audio
|
| 123 |
+
title: str = '' # Song title
|
| 124 |
+
artist: str = '' # Song artist
|
| 125 |
+
creator: str = '' # Beatmap creator
|
| 126 |
+
version: str = '' # Beatmap version
|
| 127 |
+
background: Optional[str] = None # File name of background image
|
| 128 |
+
preview_time: int = -1 # Time in milliseconds to start previewing the song
|
| 129 |
+
|
| 130 |
+
# Diffusion settings
|
| 131 |
+
generate_positions: bool = True # Use diffusion to generate object positions
|
| 132 |
+
diff_cfg_scale: float = 1.0 # Scale of classifier-free guidance
|
| 133 |
+
compile: bool = False # PyTorch 2.0 optimization
|
| 134 |
+
pad_sequence: bool = False # Pad sequence to max_seq_len
|
| 135 |
+
diff_ckpt: str = '' # Path to checkpoint for diffusion model
|
| 136 |
+
diff_refine_ckpt: str = '' # Path to checkpoint for refining diffusion model
|
| 137 |
+
beatmap_idx: str = 'osu_diffusion/beatmap_idx.pickle' # Path to beatmap index
|
| 138 |
+
refine_iters: int = 10 # Number of refinement iterations
|
| 139 |
+
random_init: bool = False # Whether to initialize with random noise instead of positions generated by the previous model
|
| 140 |
+
timesteps: list[int] = field(default_factory=lambda: [100, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # The number of timesteps we want to take from equally-sized portions of the original process
|
| 141 |
+
max_seq_len: int = 1024 # Maximum sequence length for diffusion
|
| 142 |
+
overlap_buffer: int = 128 # Buffer zone at start and end of sequence to avoid edge effects (should be less than half of max_seq_len)
|
| 143 |
+
|
| 144 |
+
# Training settings
|
| 145 |
+
train: TrainConfig = field(default_factory=TrainConfig) # Training settings for osuT5 model
|
| 146 |
+
diffusion: DiffusionTrainConfig = field(default_factory=DiffusionTrainConfig) # Training settings for diffusion model
|
| 147 |
+
|
| 148 |
+
# BeatHeritage V1 Config Sections
|
| 149 |
+
advanced_features: AdvancedFeaturesConfig = field(default_factory=AdvancedFeaturesConfig)
|
| 150 |
+
quality_control: QualityControlConfig = field(default_factory=QualityControlConfig)
|
| 151 |
+
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
| 152 |
+
metadata: MetadataConfig = field(default_factory=MetadataConfig)
|
| 153 |
+
postprocessor: PostprocessorConfig = field(default_factory=PostprocessorConfig)
|
| 154 |
+
integrations: IntegrationsConfig = field(default_factory=IntegrationsConfig)
|
| 155 |
+
hydra: Any = MISSING
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@dataclass
|
| 159 |
+
class FidConfig:
|
| 160 |
+
device: str = 'auto' # Inference device (cpu/cuda/mps/auto)
|
| 161 |
+
compile: bool = True
|
| 162 |
+
num_processes: int = 3
|
| 163 |
+
seed: int = 0
|
| 164 |
+
|
| 165 |
+
skip_generation: bool = False
|
| 166 |
+
fid: bool = True
|
| 167 |
+
rhythm_stats: bool = True
|
| 168 |
+
|
| 169 |
+
dataset_type: str = 'ors'
|
| 170 |
+
dataset_path: str = '/workspace/datasets/ORS16291'
|
| 171 |
+
dataset_start: int = 16200
|
| 172 |
+
dataset_end: int = 16291
|
| 173 |
+
gamemodes: list[int] = field(default_factory=lambda: [0]) # List of gamemodes to include in the dataset
|
| 174 |
+
|
| 175 |
+
classifier_ckpt: str = 'OliBomby/osu-classifier'
|
| 176 |
+
classifier_batch_size: int = 16
|
| 177 |
+
|
| 178 |
+
training_set_ids_path: Optional[str] = None # Path to training set beatmap IDs
|
| 179 |
+
|
| 180 |
+
inference: InferenceConfig = field(default_factory=InferenceConfig) # Training settings for osuT5 model
|
| 181 |
+
hydra: Any = MISSING
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@dataclass
|
| 185 |
+
class MaiModConfig:
|
| 186 |
+
beatmap_path: str = '' # Path to .osu file
|
| 187 |
+
audio_path: str = '' # Path to input audio
|
| 188 |
+
raw_output: bool = False
|
| 189 |
+
precision: str = 'fp32' # Lower precision for speed (fp32/bf16/amp)
|
| 190 |
+
inference: InferenceConfig = field(default_factory=InferenceConfig) # Training settings for osuT5 model
|
| 191 |
+
hydra: Any = MISSING
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
cs = ConfigStore.instance()
|
| 195 |
+
cs.store(group="inference", name="base", node=InferenceConfig)
|
| 196 |
+
cs.store(name="base_fid", node=FidConfig)
|
| 197 |
+
cs.store(name="base_mai_mod", node=MaiModConfig)
|
configs/calc_fid.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_fid
|
| 3 |
+
- inference: tiny_dist7
|
| 4 |
+
- _self_
|
| 5 |
+
|
| 6 |
+
compile: false
|
| 7 |
+
num_processes: 32
|
| 8 |
+
seed: 0
|
| 9 |
+
|
| 10 |
+
skip_generation: false
|
| 11 |
+
fid: true
|
| 12 |
+
rhythm_stats: true
|
| 13 |
+
|
| 14 |
+
classifier_ckpt: 'OliBomby/osu-classifier'
|
| 15 |
+
classifier_batch_size: 32
|
| 16 |
+
|
| 17 |
+
training_set_ids_path: null
|
| 18 |
+
|
| 19 |
+
dataset_type: "mmrs"
|
| 20 |
+
dataset_path: C:/Users/Olivier/Documents/Collections/Beatmap ML Datasets/MMRS2025
|
| 21 |
+
dataset_start: 0
|
| 22 |
+
dataset_end: 106 # Contains 324 std beatmaps
|
| 23 |
+
gamemodes: [0] # List of gamemodes to include in the dataset
|
| 24 |
+
|
| 25 |
+
inference:
|
| 26 |
+
super_timing: false
|
| 27 |
+
temperature: 0.9 # Sampling temperature
|
| 28 |
+
top_p: 0.9 # Top-p sampling threshold
|
| 29 |
+
lookback: 0.5 # Fraction of audio sequence to fill with tokens from previous inference window
|
| 30 |
+
lookahead: 0.4 # Fraction of audio sequence to skip at the end of the audio window
|
| 31 |
+
year: 2023
|
| 32 |
+
resnap_events: false
|
| 33 |
+
use_server: false
|
| 34 |
+
|
| 35 |
+
hydra:
|
| 36 |
+
job:
|
| 37 |
+
chdir: True
|
| 38 |
+
run:
|
| 39 |
+
# dir: ./logs_fid/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 40 |
+
dir: ./logs_fid/test
|
| 41 |
+
sweep:
|
| 42 |
+
dir: ./logs_fid/sweeps/test_3
|
| 43 |
+
subdir: ${hydra.job.override_dirname}
|