fourmansyah hongminh54 commited on
Commit
12a8e0f
·
0 Parent(s):

Duplicate from hongminh54/BeatHeritage-v1

Browse files

Co-authored-by: hongminh54 <hongminh54@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .devcontainer/devcontainer.json +42 -0
  2. .devcontainer/docker-compose.yml +26 -0
  3. .gitattributes +35 -0
  4. .github/FUNDING.yml +15 -0
  5. .gitignore +11 -0
  6. Dockerfile +8 -0
  7. LICENSE +21 -0
  8. README.md +323 -0
  9. audit_all_configs.py +157 -0
  10. beatheritage_postprocessor.py +474 -0
  11. benchmark_comparison.py +469 -0
  12. calc_fid.py +417 -0
  13. classifier/README.md +34 -0
  14. classifier/classify.py +175 -0
  15. classifier/configs/inference.yaml +14 -0
  16. classifier/configs/model/model.yaml +9 -0
  17. classifier/configs/model/whisper_base.yaml +6 -0
  18. classifier/configs/model/whisper_base_v2.yaml +7 -0
  19. classifier/configs/model/whisper_small.yaml +6 -0
  20. classifier/configs/model/whisper_tiny.yaml +6 -0
  21. classifier/configs/train.yaml +82 -0
  22. classifier/configs/train_v1.yaml +4 -0
  23. classifier/configs/train_v2.yaml +14 -0
  24. classifier/configs/train_v3.yaml +17 -0
  25. classifier/count_classes.py +56 -0
  26. classifier/libs/__init__.py +1 -0
  27. classifier/libs/dataset/__init__.py +3 -0
  28. classifier/libs/dataset/data_utils.py +308 -0
  29. classifier/libs/dataset/ors_dataset.py +490 -0
  30. classifier/libs/dataset/osu_parser.py +460 -0
  31. classifier/libs/model/__init__.py +1 -0
  32. classifier/libs/model/model.py +145 -0
  33. classifier/libs/model/spectrogram.py +55 -0
  34. classifier/libs/tokenizer/__init__.py +2 -0
  35. classifier/libs/tokenizer/event.py +53 -0
  36. classifier/libs/tokenizer/tokenizer.py +201 -0
  37. classifier/libs/utils/__init__.py +1 -0
  38. classifier/libs/utils/model_utils.py +190 -0
  39. classifier/libs/utils/routed_pickle.py +17 -0
  40. classifier/test.py +32 -0
  41. classifier/train.py +82 -0
  42. cli_inference.sh +491 -0
  43. colab/beatheritage_v1_inference.ipynb +510 -0
  44. colab/classifier_classify.ipynb +133 -0
  45. colab/mai_mod_inference.ipynb +148 -0
  46. colab/mapperatorinator_inference.ipynb +305 -0
  47. collate_results.py +158 -0
  48. compose.yaml +25 -0
  49. config.py +197 -0
  50. configs/calc_fid.yaml +43 -0
.devcontainer/devcontainer.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // For format details, see https://aka.ms/devcontainer.json. For config options, see the
2
+ // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-docker-compose
3
+ {
4
+ "name": "Existing Docker Compose (Extend)",
5
+
6
+ // Update the 'dockerComposeFile' list if you have more compose files or use different names.
7
+ // The .devcontainer/docker-compose.yml file contains any overrides you need/want to make.
8
+ "dockerComposeFile": [
9
+ "../compose.yaml",
10
+ "docker-compose.yml"
11
+ ],
12
+
13
+ // The 'service' property is the name of the service for the container that VS Code should
14
+ // use. Update this value and .devcontainer/docker-compose.yml to the real service name.
15
+ "service": "Mapperatorinator",
16
+
17
+ // The optional 'workspaceFolder' property is the path VS Code should open by default when
18
+ // connected. This is typically a file mount in .devcontainer/docker-compose.yml
19
+ "workspaceFolder": "/workspace/Mapperatorinator",
20
+ // "workspaceFolder": "/",
21
+
22
+ // Features to add to the dev container. More info: https://containers.dev/features.
23
+ // "features": {},
24
+
25
+ // Use 'forwardPorts' to make a list of ports inside the container available locally.
26
+ // "forwardPorts": [],
27
+
28
+ // Uncomment the next line if you want start specific services in your Docker Compose config.
29
+ // "runServices": [],
30
+
31
+ // Uncomment the next line if you want to keep your containers running after VS Code shuts down.
32
+ // "shutdownAction": "none",
33
+
34
+ // Uncomment the next line to run commands after the container is created.
35
+ "postCreateCommand": "git config --global --add safe.directory /workspace/Mapperatorinator"
36
+
37
+ // Configure tool-specific properties.
38
+ // "customizations": {},
39
+
40
+ // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
41
+ // "remoteUser": "devcontainer"
42
+ }
.devcontainer/docker-compose.yml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.8'
2
+ services:
3
+ # Update this to the name of the service you want to work with in your docker-compose.yml file
4
+ mapperatorinator:
5
+ # Uncomment if you want to override the service's Dockerfile to one in the .devcontainer
6
+ # folder. Note that the path of the Dockerfile and context is relative to the *primary*
7
+ # docker-compose.yml file (the first in the devcontainer.json "dockerComposeFile"
8
+ # array). The sample below assumes your primary file is in the root of your project.
9
+ #
10
+ # build:
11
+ # context: .
12
+ # dockerfile: .devcontainer/Dockerfile
13
+
14
+ volumes:
15
+ # Update this to wherever you want VS Code to mount the folder of your project
16
+ - ..:/workspace:cached
17
+
18
+ # Uncomment the next four lines if you will use a ptrace-based debugger like C++, Go, and Rust.
19
+ # cap_add:
20
+ # - SYS_PTRACE
21
+ # security_opt:
22
+ # - seccomp:unconfined
23
+
24
+ # Overrides default command so things don't shut down after the process ends.
25
+ command: /bin/sh -c "while sleep 1000; do :; done"
26
+
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/FUNDING.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # These are supported funding model platforms
2
+
3
+ github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4
+ patreon: # Replace with a single Patreon username
5
+ open_collective: # Replace with a single Open Collective username
6
+ ko_fi: OliBomby
7
+ tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8
+ community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9
+ liberapay: # Replace with a single Liberapay username
10
+ issuehunt: # Replace with a single IssueHunt username
11
+ lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12
+ polar: # Replace with a single Polar username
13
+ buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
14
+ thanks_dev: # Replace with a single thanks.dev username
15
+ custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv
2
+ __pycache__
3
+ logs
4
+ logs_fid
5
+ multirun
6
+ tensorboard_logs
7
+ .idea
8
+ test
9
+ test_inference.py
10
+ test_inference_mai_mod.py
11
+ .windsurf
Dockerfile ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
2
+
3
+ RUN apt-get -y update && apt-get -y upgrade && apt-get install -y git && apt-get install -y --no-install-recommends ffmpeg && rm -rf /var/lib/apt/lists/*
4
+ RUN pip install accelerate pydub nnAudio PyYAML transformers hydra-core tensorboard lightning pandas pyarrow einops 'git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh#egg=slider' torch_tb_profiler wandb ninja
5
+ RUN MAX_JOBS=4 pip install flash-attn --no-build-isolation
6
+
7
+ # Modify .bashrc to include the custom prompt
8
+ RUN echo 'if [ -f /.dockerenv ]; then export PS1="(docker) $PS1"; fi' >> /root/.bashrc
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 OliBomby
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BeatHeritage
2
+
3
+ 🎯 **NEW: BeatHeritage V1 - Enhanced Stability & Quality** | [Try on Colab](https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb) | [Documentation](docs/BEATHERITAGE_V1.md)
4
+
5
+ Try the generative model [here](https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb), or MaiMod [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb). Check out a video showcase [here](https://youtu.be/FEr7t1L2EoA).
6
+
7
+ BeatHeritage (formerly Mapperatorinator) is a multi-model framework that uses spectrogram inputs to generate fully featured osu! beatmaps for all gamemodes and [assist modding beatmaps](#maimod-the-ai-driven-modding-tool).
8
+ The goal of this project is to automatically generate rankable quality osu! beatmaps from any song with a high degree of customizability.
9
+
10
+ ## 🚀 What's New in BeatHeritage V1
11
+
12
+ - **Enhanced Stability**: Optimized sampling parameters (temperature 0.85, top_p 0.92) for more consistent generation
13
+ - **Quality Control**: Automatic spacing correction, overlap detection, and flow optimization
14
+ - **Pattern Variety**: Advanced pattern generation with diversity enhancement
15
+ - **All Gamemodes**: Full support for std, taiko, ctb, and mania with mode-specific optimizations
16
+ - **Performance**: Flash attention, mixed precision (BF16), and gradient checkpointing
17
+ - **Custom Postprocessor**: Advanced post-processing with flow optimization and style preservation
18
+ - **Benchmark Tools**: Compare performance with previous models
19
+ - **Easy Setup**: Auto-setup script with model downloading from Hugging Face
20
+
21
+ This project is built upon [osuT5](https://github.com/gyataro/osuT5) and [osu-diffusion](https://github.com/OliBomby/osu-diffusion). In developing this, I spent about 2500 hours of GPU compute across 142 runs on my 4060 Ti and rented 4090 instances on vast.ai.
22
+
23
+ #### Use this tool responsibly. Always disclose the use of AI in your beatmaps.
24
+
25
+ ## Installation
26
+
27
+ The instruction below allows you to generate beatmaps on your local machine, alternatively you can run it in the cloud with the [colab notebook](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mapperatorinator_inference.ipynb).
28
+
29
+ ### 1. Clone the repository
30
+
31
+ ```sh
32
+ git clone https://github.com/OliBomby/Mapperatorinator.git
33
+ cd Mapperatorinator
34
+ ```
35
+
36
+ ### 2. (Optional) Create virtual environment
37
+
38
+ Use Python 3.10, later versions might not be compatible with the dependencies.
39
+
40
+ ```sh
41
+ python -m venv .venv
42
+
43
+ # In cmd.exe
44
+ .venv\Scripts\activate.bat
45
+ # In PowerShell
46
+ .venv\Scripts\Activate.ps1
47
+ # In Linux or MacOS
48
+ source .venv/bin/activate
49
+ ```
50
+
51
+ ### 3. Install dependencies
52
+
53
+ - Python 3.10
54
+ - [Git](https://git-scm.com/downloads)
55
+ - [ffmpeg](http://www.ffmpeg.org/)
56
+ - [PyTorch](https://pytorch.org/get-started/locally/): Make sure to follow the Get Started guide so you install `torch` and `torchaudio` with GPU support.
57
+
58
+ - and the remaining Python dependencies:
59
+
60
+ ```sh
61
+ pip install -r requirements.txt
62
+ ```
63
+
64
+ ## Web GUI (Recommended)
65
+
66
+ For a more user-friendly experience, consider using the Web UI. It provides a graphical interface to configure generation parameters, start the process, and monitor the output.
67
+
68
+ ### Launch the GUI
69
+
70
+ Navigate to the cloned `Mapperatorinator` directory in your terminal and run:
71
+
72
+ ```sh
73
+ python web-ui.py
74
+ ```
75
+
76
+ This will start a local web server and automatically open the UI in a new window.
77
+
78
+ ### Using the GUI
79
+
80
+ - **Configure:** Set input/output paths using the form fields and "Browse" buttons. Adjust generation parameters like gamemode, difficulty, style (year, mapper ID, descriptors), timing, specific features (hitsounds, super timing), and more, mirroring the command-line options. (Note: If you provide a `beatmap_path`, the UI will automatically determine the `audio_path` and `output_path` from it, so you can leave those fields blank)
81
+ - **Start:** Click the "Start Inference" button to begin the beatmap generation.
82
+ - **Cancel:** You can stop the ongoing process using the "Cancel Inference" button.
83
+ - **Open Output:** Once finished, use the "Open Output Folder" button for quick access to the generated files.
84
+
85
+ The Web UI acts as a convenient wrapper around the `inference.py` script. For advanced options or troubleshooting, refer to the command-line instructions.
86
+
87
+ ![python_u3zyW0S3Vs](https://github.com/user-attachments/assets/5312a45f-d51c-4b37-9389-da3258ddd0a1)
88
+
89
+ ## Command-Line Inference
90
+
91
+ For users who prefer the command line or need access to advanced configurations, follow the steps below. **Note:** For a simpler graphical interface, please see the [Web UI (Recommended)](#web-ui-recommended) section above.
92
+
93
+ Run `inference.py` and pass in some arguments to generate beatmaps. For this use [Hydra override syntax](https://hydra.cc/docs/advanced/override_grammar/basic/). See `configs/inference_v29.yaml` for all available parameters.
94
+ ```
95
+ python inference.py \
96
+ audio_path [Path to input audio] \
97
+ output_path [Path to output directory] \
98
+ beatmap_path [Path to .osu file to autofill metadata, and output_path, or use as reference] \
99
+
100
+ gamemode [Game mode to generate 0=std, 1=taiko, 2=ctb, 3=mania] \
101
+ difficulty [Difficulty star rating to generate] \
102
+ mapper_id [Mapper user ID for style] \
103
+ year [Upload year to simulate] \
104
+ hitsounded [Whether to add hitsounds] \
105
+ slider_multiplier [Slider velocity multiplier] \
106
+ circle_size [Circle size] \
107
+ keycount [Key count for mania] \
108
+ hold_note_ratio [Hold note ratio for mania 0-1] \
109
+ scroll_speed_ratio [Scroll speed ratio for mania and ctb 0-1] \
110
+ descriptors [List of beatmap user tags for style] \
111
+ negative_descriptors [List of beatmap user tags for classifier-free guidance] \
112
+
113
+ add_to_beatmap [Whether to add generated content to the reference beatmap instead of making a new beatmap] \
114
+ start_time [Generation start time in milliseconds] \
115
+ end_time [Generation end time in milliseconds] \
116
+ in_context [List of additional context to provide to the model [NONE,TIMING,KIAI,MAP,GD,NO_HS]] \
117
+ output_type [List of content types to generate] \
118
+ cfg_scale [Scale of the classifier-free guidance] \
119
+ super_timing [Whether to use slow accurate variable BPM timing generator] \
120
+ seed [Random seed for generation] \
121
+ ```
122
+
123
+ Example:
124
+ ```
125
+ python inference.py beatmap_path="'C:\Users\USER\AppData\Local\osu!\Songs\1 Kenji Ninuma - DISCO PRINCE\Kenji Ninuma - DISCOPRINCE (peppy) [Normal].osu'" gamemode=0 difficulty=5.5 year=2023 descriptors="['jump aim','clean']" in_context=[TIMING,KIAI]
126
+ ```
127
+
128
+ ## Interactive CLI
129
+ For those who prefer a terminal-based workflow but want a guided setup, the interactive CLI script is an excellent alternative to the Web UI.
130
+
131
+ ### Launch the CLI
132
+ Navigate to the cloned directory. You may need to make the script executable first.
133
+
134
+ ```sh
135
+ # Make the script executable (only needs to be done once)
136
+ chmod +x cli_inference.sh
137
+ ```
138
+
139
+ ```sh
140
+ # Run the script
141
+ ./cli_inference.sh
142
+ ```
143
+
144
+ ### Using the CLI
145
+ The script will walk you through a series of prompts to configure all generation parameters, just like the Web UI.
146
+
147
+ It uses a color-coded interface for clarity.
148
+ It provides an advanced multi-select menu for choosing style descriptors using your arrow keys and spacebar.
149
+ After you've answered all the questions, it will display the final command for your review.
150
+ You can then confirm to execute it directly or cancel and copy the command for manual use.
151
+
152
+ ## Generation Tips
153
+
154
+ - You can edit `configs/inference_v29.yaml` and add your arguments there instead of typing them in the terminal every time.
155
+ - All available descriptors can be found [here](https://osu.ppy.sh/wiki/en/Beatmap/Beatmap_tags).
156
+ - Always provide a year argument between 2007 and 2023. If you leave it unknown, the model might generate with an inconsistent style.
157
+ - Always provide a difficulty argument. If you leave it unknown, the model might generate with an inconsistent difficulty.
158
+ - Increase the `cfg_scale` parameter to increase the effectiveness of the `mapper_id` and `descriptors` arguments.
159
+ - You can use the `negative_descriptors` argument to guide the model away from certain styles. This only works when `cfg_scale > 1`. Make sure the number of negative descriptors is equal to the number of descriptors.
160
+ - If your song style and desired beatmap style don't match well, the model might not follow your directions. For example, it's hard to generate a high SR, high SV beatmap for a calm song.
161
+ - If you already have timing and kiai times done for a song, then you can give this to the model to greatly increase inference speed and accuracy: Use the `beatmap_path` and `in_context=[TIMING,KIAI]` arguments.
162
+ - To remap just a part of your beatmap, use the `beatmap_path`, `start_time`, `end_time`, and `add_to_beatmap=true` arguments.
163
+ - To generate a guest difficulty for a beatmap, use the `beatmap_path` and `in_context=[GD,TIMING,KIAI]` arguments.
164
+ - To generate hitsounds for a beatmap, use the `beatmap_path` and `in_context=[NO_HS,TIMING,KIAI]` arguments.
165
+ - To generate only timing for a song, use the `super_timing=true` and `output_type=[TIMING]` arguments.
166
+
167
+ ## MaiMod: The AI-driven Modding Tool
168
+
169
+ MaiMod is a modding tool for osu! beatmaps that uses Mapperatorinator predictions to find potential faults and inconsistencies which can't be detected by other automatic modding tools like [Mapset Verifier](https://github.com/Naxesss/MapsetVerifier).
170
+ It can detect issues like:
171
+ - Incorrect snapping or rhythmic patterns
172
+ - Inaccurate timing points
173
+ - Inconsistent hit object positions or new combo placements
174
+ - Weird slider shapes
175
+ - Inconsistent hitsounds or volumes
176
+
177
+ You can try MaiMod [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb), or run it locally:
178
+ To run MaiMod locally, you'll need to install Mapperatorinator. Then, run the `mai_mod.py` script, specifying your beatmap's path with the `beatmap_path` argument.
179
+ ```sh
180
+ python mai_mod.py beatmap_path="'C:\Users\USER\AppData\Local\osu!\Songs\1 Kenji Ninuma - DISCO PRINCE\Kenji Ninuma - DISCOPRINCE (peppy) [Normal].osu'"
181
+ ```
182
+ This will print the modding suggestions to the console, which you can then apply to your beatmap manually.
183
+ Suggestions are ordered chronologically and grouped into categories.
184
+ The first value in the circle indicates the 'surprisal' which is a measure of how unexpected the model found the issue to be, so you can prioritize the most important issues.
185
+
186
+ The model can make mistakes, especially on low surprisal issues, so always double-check the suggestions before applying them to your beatmap.
187
+ The main goal is to help you narrow down the search space for potential issues, so you don't have to manually check every single hit object in your beatmap.
188
+
189
+ ### MaiMod GUI
190
+ To run the MaiMod Web UI, you'll need to install Mapperatorinator.
191
+ Then, run the `mai_mod_ui.py` script. This will start a local web server and automatically open the UI in a new window:
192
+
193
+ ```sh
194
+ python mai_mod_ui.py
195
+ ```
196
+
197
+ <img width="850" height="1019" alt="afbeelding" src="https://github.com/user-attachments/assets/67c03a43-a7bd-4265-a5b1-5e4d62aca1fa" />
198
+
199
+ ## Overview
200
+
201
+ ### Tokenization
202
+
203
+ Mapperatorinator converts osu! beatmaps into an intermediate event representation that can be directly converted to and from tokens.
204
+ It includes hit objects, hitsounds, slider velocities, new combos, timing points, kiai times, and taiko/mania scroll speeds.
205
+
206
+ Here is a small example of the tokenization process:
207
+
208
+ ![mapperatorinator_parser](https://github.com/user-attachments/assets/84efde76-4c27-48a1-b8ce-beceddd9e695)
209
+
210
+ To save on vocabulary size, time events are quantized to 10ms intervals and position coordinates are quantized to 32 pixel grid points.
211
+
212
+ ### Model architecture
213
+ The model is basically a wrapper around the [HF Transformers Whisper](https://huggingface.co/docs/transformers/en/model_doc/whisper#transformers.WhisperForConditionalGeneration) model, with custom input embeddings and loss function.
214
+ Model size amounts to 219M parameters.
215
+ This model was found to be faster and more accurate than T5 for this task.
216
+
217
+ The high-level overview of the model's input-output is as follows:
218
+
219
+ ![Picture2](https://user-images.githubusercontent.com/28675590/201044116-1384ad72-c540-44db-a285-7319dd01caad.svg)
220
+
221
+ The model uses Mel spectrogram frames as encoder input, with one frame per input position. The model decoder output at each step is a softmax distribution over a discrete, predefined, vocabulary of events. Outputs are sparse, events are only needed when a hit-object occurs, instead of annotating every single audio frame.
222
+
223
+ ### Multitask training format
224
+
225
+ ![Multitask training format](https://github.com/user-attachments/assets/62f490bc-a567-4671-a7ce-dbcc5f9cd6d9)
226
+
227
+ Before the SOS token are additional tokens that facilitate conditional generation. These tokens include the gamemode, difficulty, mapper ID, year, and other metadata.
228
+ During training, these tokens do not have accompanying labels, so they are never output by the model.
229
+ Also during training there is a random chance that a metadata token gets replaced by an 'unknown' token, so during inference we can use these 'unknown' tokens to reduce the amount of metadata we have to give to the model.
230
+
231
+ ### Seamless long generation
232
+
233
+ The context length of the model is 8.192 seconds long. This is obviously not enough to generate a full beatmap, so we have to split the song into multiple windows and generate the beatmap in small parts.
234
+ To make sure that the generated beatmap does not have noticeable seams in between windows, we use a 90% overlap and generate the windows sequentially.
235
+ Each generation window except the first starts with the decoder pre-filled up to 50% of the generation window with tokens from the previous windows.
236
+ We use a logit processor to make sure that the model can't generate time tokens that are in the first 50% of the generation window.
237
+ Additionally, the last 40% of the generation window is reserved for the next window. Any generated time tokens in that range are treated as EOS tokens.
238
+ This ensures that each generated token is conditioned on at least 4 seconds of previous tokens and 3.3 seconds of future audio to anticipate.
239
+
240
+ To prevent offset drifting during long generation, random offsets have been added to time events in the decoder during training.
241
+ This forces it to correct timing errors by listening to the onsets in the audio instead, and results in a consistently accurate offset.
242
+
243
+ ### Refined coordinates with diffusion
244
+
245
+ Position coordinates generated by the decoder are quantized to 32 pixel grid points, so afterward we use diffusion to denoise the coordinates to the final positions.
246
+ For this we trained a modified version of [osu-diffusion](https://github.com/OliBomby/osu-diffusion) that is specialized to only the last 10% of the noise schedule, and accepts the more advanced metadata tokens that Mapperatorinator uses for conditional generation.
247
+
248
+ Since the Mapperatorinator model outputs the SV of sliders, the required length of the slider is fixed regardless of the shape of the control point path.
249
+ Therefore, we try to guide the diffusion process to create coordinates that fit the required slider lengths.
250
+ We do this by recalculating the slider end positions after every step of the diffusion process based on the required length and the current control point path.
251
+ This means that the diffusion process does not have direct control over the slider end positions, but it can still influence them by changing the control point path.
252
+
253
+ ### Post-processing
254
+
255
+ Mapperatorinator does some extra post-processing to improve the quality of the generated beatmap:
256
+
257
+ - Refine position coordinates with diffusion.
258
+ - Resnap time events to the nearest tick using the snap divisors generated by the model.
259
+ - Snap near-perfect positional overlaps.
260
+ - Convert mania column events to X coordinates.
261
+ - Generate slider paths for taiko drumrolls.
262
+ - Fix big discrepancies in required slider length and control point path length.
263
+
264
+ ### Super timing generator
265
+
266
+ Super timing generator is an algorithm that improves the precision and accuracy of generated timing by inferring timing for the whole song 20 times and averaging the results.
267
+ This is useful for songs with variable BPM, or songs with BPM changes. The result is almost perfect with only sometimes a section that needs manual adjustment.
268
+
269
+ ## Training
270
+
271
+ The instruction below creates a training environment on your local machine.
272
+
273
+ ### 1. Clone the repository
274
+
275
+ ```sh
276
+ git clone https://github.com/OliBomby/Mapperatorinator.git
277
+ cd Mapperatorinator
278
+ ```
279
+
280
+ ### 2. Create dataset
281
+
282
+ Create your own dataset using the [Mapperator console app](https://github.com/mappingtools/Mapperator/blob/master/README.md#create-a-high-quality-dataset). It requires an [osu! OAuth client token](https://osu.ppy.sh/home/account/edit) to verify beatmaps and get additional metadata. Place the dataset in a `datasets` directory next to the `Mapperatorinator` directory.
283
+
284
+ ```sh
285
+ Mapperator.ConsoleApp.exe dataset2 -t "/Mapperatorinator/datasets/beatmap_descriptors.csv" -i "path/to/osz/files" -o "/datasets/cool_dataset"
286
+ ```
287
+
288
+ ### 3. Create docker container
289
+ Training in your venv is also possible, but we recommend using Docker on WSL for better performance.
290
+ ```sh
291
+ docker compose up -d --force-recreate
292
+ docker attach mapperatorinator_space
293
+ ```
294
+
295
+ ### 4. Configure parameters and begin training
296
+
297
+ All configurations are located in `./configs/osut5/train.yaml`. Begin training by calling `osuT5/train.py`.
298
+
299
+ ```sh
300
+ python osuT5/train.py -cn train_v29 train_dataset_path="/workspace/datasets/cool_dataset" test_dataset_path="/workspace/datasets/cool_dataset" train_dataset_end=90 test_dataset_start=90 test_dataset_end=100
301
+ ```
302
+
303
+ ## See also
304
+ - [Mapper Classifier](./classifier/README.md)
305
+ - [RComplexion](./rcomplexion/README.md)
306
+
307
+ ## Credits
308
+
309
+ Special thanks to:
310
+ 1. The authors of [osuT5](https://github.com/gyataro/osuT5) for their training code.
311
+ 2. Hugging Face team for their [tools](https://huggingface.co/docs/transformers/index).
312
+ 3. [Jason Won](https://github.com/jaswon) and [Richard Nagyfi](https://github.com/sedthh) for bouncing ideas.
313
+ 4. [Marvin](https://github.com/minetoblend) for donating training credits.
314
+ 5. The osu! community for the beatmaps.
315
+
316
+ ## Related works
317
+
318
+ 1. [osu! Beatmap Generator](https://github.com/Syps/osu_beatmap_generator) by Syps (Nick Sypteras)
319
+ 2. [osumapper](https://github.com/kotritrona/osumapper) by kotritrona, jyvden, Yoyolick (Ryan Zmuda)
320
+ 3. [osu-diffusion](https://github.com/OliBomby/osu-diffusion) by OliBomby (Olivier Schipper), NiceAesth (Andrei Baciu)
321
+ 4. [osuT5](https://github.com/gyataro/osuT5) by gyataro (Xiwen Teoh)
322
+ 5. [Beat Learning](https://github.com/sedthh/BeatLearning) by sedthh (Richard Nagyfi)
323
+ 6. [osu!dreamer](https://github.com/jaswon/osu-dreamer) by jaswon (Jason Won)
audit_all_configs.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Comprehensive Config Audit Script for BeatHeritage
4
+ Checks all config files against their corresponding dataclass definitions
5
+ """
6
+
7
+ import os
8
+ import yaml
9
+ from pathlib import Path
10
+ from dataclasses import fields
11
+ from typing import Dict, List, Set, Any
12
+
13
+ # Import all config classes
14
+ from config import InferenceConfig, FidConfig, MaiModConfig
15
+ from osuT5.osuT5.config import TrainConfig, DataConfig, DataloaderConfig, OptimizerConfig
16
+ from osu_diffusion.config import DiffusionTrainConfig
17
+
18
def get_config_fields(config_class) -> Set[str]:
    """Return the name of every field declared on the given dataclass."""
    names: Set[str] = set()
    for spec in fields(config_class):
        names.add(spec.name)
    return names
21
+
22
def get_yaml_keys(yaml_path: str, prefix: str = "") -> Set[str]:
    """Collect every key in a YAML file as dot-separated paths.

    Nested mapping keys become "parent.child" entries; a list whose first
    element is a mapping is descended through that first element. Hydra's
    `defaults` list is skipped. (`prefix` is accepted for interface
    compatibility but is not used.)
    """
    collected: Set[str] = set()

    def walk(node: Any, path: str = "") -> None:
        if not isinstance(node, dict):
            return
        for name, child in node.items():
            if name == 'defaults':  # Hydra composition list, not a config field
                continue
            dotted = f"{path}.{name}" if path else name
            collected.add(dotted)
            if isinstance(child, dict):
                walk(child, dotted)
            elif isinstance(child, list) and child and isinstance(child[0], dict):
                # Represent a list of mappings by its first element.
                walk(child[0], dotted)

    try:
        with open(yaml_path, 'r') as f:
            walk(yaml.safe_load(f))
    except Exception as e:
        print(f"Error reading {yaml_path}: {e}")

    return collected
51
+
52
def audit_config_mapping(config_path: str, config_class, config_name: str):
    """Audit one YAML config file against its dataclass definition.

    Prints a summary of field/key counts and any mismatches.

    Returns:
        A dict with keys 'missing_in_class', 'missing_in_config',
        'all_yaml_keys' and 'class_fields', or None when the config file
        does not exist.
    """
    print(f"\n[AUDIT] {config_name}: {config_path}")

    if not os.path.exists(config_path):
        print(f"[ERROR] Config file not found: {config_path}")
        return None  # explicit: callers check the result's truthiness

    # Fields declared on the dataclass vs. keys present in the YAML file.
    class_fields = get_config_fields(config_class)
    yaml_keys = get_yaml_keys(config_path)

    # Find mismatches in both directions.
    missing_in_class = yaml_keys - class_fields
    missing_in_config = class_fields - yaml_keys

    # Dataclass fields are flat names, so the authoritative comparison is
    # against top-level YAML keys only.
    top_level_yaml = {key.split('.')[0] for key in yaml_keys}
    top_level_missing = top_level_yaml - class_fields

    print("[SUMMARY]:")
    print(f" - Dataclass fields: {len(class_fields)}")
    print(f" - YAML keys (all): {len(yaml_keys)}")
    print(f" - YAML keys (top-level): {len(top_level_yaml)}")

    if top_level_missing:
        print("[MISSING] Keys in YAML but missing in dataclass:")
        for key in sorted(top_level_missing):
            related_keys = [k for k in yaml_keys if k.startswith(key)]
            print(f" - {key} (related: {len(related_keys)} keys)")
            if len(related_keys) <= 5:  # Show details for small sections
                for rkey in sorted(related_keys)[:5]:
                    print(f" * {rkey}")
            else:
                # Bug fix: previously printed "... and N more" without having
                # shown any keys. Show the first 3, then summarize the rest,
                # which makes the `len - 3` arithmetic correct.
                for rkey in sorted(related_keys)[:3]:
                    print(f" * {rkey}")
                print(f" * ... and {len(related_keys)-3} more keys")

    if missing_in_config:
        optional_missing = missing_in_config & {'hydra', 'train', 'diffusion'}  # Usually optional
        real_missing = missing_in_config - optional_missing
        if real_missing:
            print("[WARNING] Fields in dataclass but missing in YAML:")
            for key in sorted(real_missing):
                print(f" - {key}")

    return {
        'missing_in_class': top_level_missing,
        'missing_in_config': missing_in_config,
        'all_yaml_keys': yaml_keys,
        'class_fields': class_fields
    }
104
+
105
def main():
    """Run the full config audit and print a report of every mismatch found."""
    print("BeatHeritage Config Audit - Finding ALL Mismatches")
    print("=" * 60)

    # (config file path, dataclass it must match, human-readable label)
    audit_targets = [
        ("configs/inference/beatheritage_v1.yaml", InferenceConfig, "Inference (BeatHeritage V1)"),
        ("configs/inference/default.yaml", InferenceConfig, "Inference (Default)"),
        ("configs/train/beatheritage_v1.yaml", TrainConfig, "Training (BeatHeritage V1)"),
        ("configs/train/default.yaml", TrainConfig, "Training (Default)"),
        ("configs/diffusion/v1.yaml", DiffusionTrainConfig, "Diffusion (V1)"),
    ]

    # Keep only the configs whose YAML has keys the dataclass lacks.
    all_issues = {}
    for path, cls, label in audit_targets:
        findings = audit_config_mapping(path, cls, label)
        if findings and findings['missing_in_class']:
            all_issues[label] = findings

    print("\nAUDIT SUMMARY")
    print("=" * 60)

    if not all_issues:
        print("All configs are aligned with their dataclasses!")
        return

    print(f"Found issues in {len(all_issues)} config(s):")
    for config_name, issues in all_issues.items():
        print(f"\n{config_name}:")
        for key in sorted(issues['missing_in_class']):
            print(f" - Missing field: {key}")

    # Emit copy-pasteable stubs for the fields that need to be added.
    print("\nSUGGESTED FIXES")
    print("=" * 60)
    for config_name, issues in all_issues.items():
        if 'Inference' in config_name:
            print("\nFor InferenceConfig class:")
            for key in sorted(issues['missing_in_class']):
                print(f" + {key}: <appropriate_type> = <default_value>")


if __name__ == "__main__":
    main()
beatheritage_postprocessor.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BeatHeritage V1 Custom Postprocessor
3
+ Enhanced postprocessing for improved beatmap quality
4
+ """
5
+
6
+ import numpy as np
7
+ from typing import List, Tuple, Dict, Optional
8
+ from dataclasses import dataclass
9
+ import logging
10
+
11
+ from osuT5.osuT5.inference.postprocessor import Postprocessor, BeatmapConfig
12
+
13
logger = logging.getLogger(__name__)  # module-level logger for postprocessing diagnostics
14
+
15
+
16
@dataclass
class BeatHeritageConfig(BeatmapConfig):
    """Enhanced config for BeatHeritage V1 postprocessing.

    Extends the base BeatmapConfig with toggles and tuning knobs for the
    quality-control, flow, pattern, difficulty-scaling and style stages
    applied by BeatHeritagePostprocessor.
    """
    # Quality control parameters
    # Minimum allowed distance (osu! playfield pixels) between consecutive objects.
    min_distance_threshold: float = 20.0
    # Maximum tolerated circle-overlap ratio before positions are pushed apart.
    max_overlap_ratio: float = 0.15
    enable_auto_correction: bool = True
    enable_flow_optimization: bool = True

    # Pattern enhancement
    enable_pattern_variety: bool = True
    # NOTE(review): not read by the visible postprocessing code — confirm intended use.
    pattern_complexity_target: float = 0.7

    # Difficulty scaling
    enable_difficulty_scaling: bool = True
    # NOTE(review): not read by the visible postprocessing code — confirm intended use.
    difficulty_variance_threshold: float = 0.3

    # Style preservation
    enable_style_preservation: bool = True
    # Blend weight in [0, 1]: how strongly spacing is pulled toward the map's
    # average distance in _preserve_style (1.0 = force the average everywhere).
    style_consistency_weight: float = 0.8
36
+
37
+
38
class BeatHeritagePostprocessor(Postprocessor):
    """Enhanced postprocessor for BeatHeritage V1.

    Runs the base Postprocessor pipeline, then a sequence of optional
    refinement stages (spacing/overlap fixes, flow smoothing, pattern
    variety, difficulty scaling, style preservation) controlled by flags
    on BeatHeritageConfig. All stages mutate the beatmap dict in place
    and also return it.
    """

    def __init__(self, config: BeatHeritageConfig):
        super().__init__(config)
        self.config = config
        # Stage helpers share the same config object.
        self.flow_optimizer = FlowOptimizer(config)
        self.pattern_enhancer = PatternEnhancer(config)
        self.quality_controller = QualityController(config)

    def postprocess(self, beatmap_data: Dict) -> Dict:
        """
        Enhanced postprocessing pipeline for BeatHeritage V1

        Args:
            beatmap_data: Raw beatmap data from model

        Returns:
            Processed beatmap data with enhancements
        """
        # Base postprocessing (inherited behavior — see Postprocessor)
        beatmap_data = super().postprocess(beatmap_data)

        # Quality control: enforce minimum spacing, then resolve overlaps
        if self.config.enable_auto_correction:
            beatmap_data = self.quality_controller.fix_spacing_issues(beatmap_data)
            beatmap_data = self.quality_controller.fix_overlaps(beatmap_data)

        # Flow optimization: smooth out sharp direction changes
        if self.config.enable_flow_optimization:
            beatmap_data = self.flow_optimizer.optimize_flow(beatmap_data)

        # Pattern enhancement: break up repetitive sections
        if self.config.enable_pattern_variety:
            beatmap_data = self.pattern_enhancer.enhance_patterns(beatmap_data)

        # Difficulty scaling toward the configured target
        if self.config.enable_difficulty_scaling:
            beatmap_data = self._scale_difficulty(beatmap_data)

        # Style preservation: regularize spacing toward the map average
        if self.config.enable_style_preservation:
            beatmap_data = self._preserve_style(beatmap_data)

        return beatmap_data

    def _scale_difficulty(self, beatmap_data: Dict) -> Dict:
        """Scale difficulty to match target star rating"""
        # NOTE(review): assumes BeatmapConfig declares a `difficulty` field
        # (not visible here) — confirm against the base class.
        target_difficulty = self.config.difficulty
        if target_difficulty is None:
            return beatmap_data

        current_difficulty = self._calculate_difficulty(beatmap_data)
        # Guard against division by ~zero with the 0.1 floor.
        scale_factor = target_difficulty / max(current_difficulty, 0.1)

        # Adjust spacing and timing based on scale factor
        # (only objects that carry an explicit 'distance' key are touched)
        if 'hit_objects' in beatmap_data:
            for obj in beatmap_data['hit_objects']:
                if 'distance' in obj:
                    obj['distance'] *= scale_factor

        logger.info(f"Scaled difficulty from {current_difficulty:.2f} to {target_difficulty:.2f}")
        return beatmap_data

    def _preserve_style(self, beatmap_data: Dict) -> Dict:
        """Preserve mapping style consistency"""
        # Analyze style characteristics (average distance, variance)
        style_features = self._extract_style_features(beatmap_data)

        # Apply style consistency
        consistency_weight = self.config.style_consistency_weight

        if 'hit_objects' in beatmap_data:
            for i, obj in enumerate(beatmap_data['hit_objects']):
                if i > 0:
                    # Maintain consistent spacing patterns.
                    # NOTE: prev_obj is read from the list being mutated, so
                    # adjustments propagate sequentially through the map.
                    prev_obj = beatmap_data['hit_objects'][i-1]
                    expected_distance = style_features.get('avg_distance', 100)

                    if 'position' in obj and 'position' in prev_obj:
                        current_distance = self._calculate_distance(
                            obj['position'], prev_obj['position']
                        )

                        # Blend current with expected based on consistency weight
                        adjusted_distance = (
                            current_distance * (1 - consistency_weight) +
                            expected_distance * consistency_weight
                        )

                        # Adjust position to match distance
                        obj['position'] = self._adjust_position(
                            prev_obj['position'],
                            obj['position'],
                            adjusted_distance
                        )

        return beatmap_data

    def _calculate_difficulty(self, beatmap_data: Dict) -> float:
        """Calculate approximate star rating"""
        # Simplified difficulty calculation: object density x spacing x tempo
        num_objects = len(beatmap_data.get('hit_objects', []))
        avg_spacing = self._calculate_avg_spacing(beatmap_data)
        bpm = beatmap_data.get('bpm', 180)

        # Simple formula (can be improved)
        difficulty = (num_objects / 100) * (avg_spacing / 50) * (bpm / 180)
        return min(max(difficulty, 0), 10)  # Clamp to 0-10

    def _extract_style_features(self, beatmap_data: Dict) -> Dict:
        """Extract style characteristics from beatmap"""
        features = {}

        if 'hit_objects' in beatmap_data:
            # Distances between consecutive objects; the previous object falls
            # back to the playfield center when it has no position.
            distances = []
            for i in range(1, len(beatmap_data['hit_objects'])):
                if 'position' in beatmap_data['hit_objects'][i]:
                    dist = self._calculate_distance(
                        beatmap_data['hit_objects'][i-1].get('position', (256, 192)),
                        beatmap_data['hit_objects'][i]['position']
                    )
                    distances.append(dist)

            if distances:
                features['avg_distance'] = np.mean(distances)
                features['distance_variance'] = np.var(distances)

        return features

    def _calculate_avg_spacing(self, beatmap_data: Dict) -> float:
        """Calculate average spacing between objects (100 when none measurable)"""
        distances = []
        objects = beatmap_data.get('hit_objects', [])

        for i in range(1, len(objects)):
            # Only pairs where both objects carry a position count.
            if 'position' in objects[i] and 'position' in objects[i-1]:
                dist = self._calculate_distance(
                    objects[i-1]['position'],
                    objects[i]['position']
                )
                distances.append(dist)

        return np.mean(distances) if distances else 100

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Calculate Euclidean distance between two positions"""
        return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)

    def _adjust_position(self, from_pos: Tuple[float, float],
                         to_pos: Tuple[float, float],
                         target_distance: float) -> Tuple[float, float]:
        """Adjust position to achieve target distance from `from_pos`.

        Scales the from→to vector so its length equals target_distance,
        then clamps the result to the 512x384 playfield (clamping may
        shorten the achieved distance at the edges).
        """
        current_distance = self._calculate_distance(from_pos, to_pos)
        if current_distance < 0.01:  # Avoid division by zero
            return to_pos

        scale = target_distance / current_distance
        dx = (to_pos[0] - from_pos[0]) * scale
        dy = (to_pos[1] - from_pos[1]) * scale

        # Keep within playfield bounds
        new_x = max(0, min(512, from_pos[0] + dx))
        new_y = max(0, min(384, from_pos[1] + dy))

        return (new_x, new_y)
205
+
206
+
207
class FlowOptimizer:
    """Smooths out sharp direction changes between consecutive hit objects."""

    def __init__(self, config: BeatHeritageConfig):
        self.config = config

    def optimize_flow(self, beatmap_data: Dict) -> Dict:
        """Reduce sharp turns (>120 degrees) for better playability.

        For every object from the third onward, the turn angle relative to
        the two preceding objects is measured; sharp turns are re-aimed to a
        90-degree turn at the same travel distance, clamped to the playfield.
        Mutates and returns `beatmap_data`.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        objects = beatmap_data['hit_objects']
        smoothed = []
        fallback = (256, 192)  # playfield center when a position is missing

        for idx, obj in enumerate(objects):
            if idx >= 2 and 'position' in obj:
                anchor = objects[idx - 2].get('position', fallback)
                pivot = objects[idx - 1].get('position', fallback)
                incoming = self._calculate_angle(anchor, pivot)
                outgoing = self._calculate_angle(pivot, obj['position'])

                if abs(outgoing - incoming) > 120:  # sharp-angle threshold
                    # Re-aim to a 90-degree turn on the same side, keeping the
                    # original travel distance.
                    target = incoming + np.sign(outgoing - incoming) * 90
                    step = self._calculate_distance(
                        objects[idx - 1]['position'], obj['position']
                    )
                    px = objects[idx - 1]['position'][0] + step * np.cos(np.radians(target))
                    py = objects[idx - 1]['position'][1] + step * np.sin(np.radians(target))
                    obj['position'] = (
                        max(0, min(512, px)),
                        max(0, min(384, py)),
                    )

            smoothed.append(obj)

        beatmap_data['hit_objects'] = smoothed
        return beatmap_data

    def _calculate_angle(self, pos1: Tuple[float, float],
                         pos2: Tuple[float, float]) -> float:
        """Direction of travel from pos1 to pos2, in degrees."""
        dy = pos2[1] - pos1[1]
        dx = pos2[0] - pos1[0]
        return np.degrees(np.arctan2(dy, dx))

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Euclidean distance between two positions."""
        return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)
265
+
266
+
267
class PatternEnhancer:
    """Inject variety into repetitive sections of a beatmap.

    Detects windows of hit objects whose positional pattern repeats
    back-to-back and rewrites them using shapes from a small built-in
    pattern library.
    """

    def __init__(self, config: BeatHeritageConfig):
        self.config = config
        self.pattern_library = self._load_pattern_library()

    def enhance_patterns(self, beatmap_data: Dict) -> Dict:
        """Replace detected repetitive sections with library patterns.

        Args:
            beatmap_data: Beatmap dict; only 'hit_objects' positions are touched.

        Returns:
            The (mutated) beatmap dict.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        # Detect repetitive windows, then overwrite each with a varied shape.
        repetitive_sections = self._detect_repetitive_patterns(beatmap_data)
        for section in repetitive_sections:
            beatmap_data = self._vary_pattern(beatmap_data, section)

        return beatmap_data

    def _load_pattern_library(self) -> List[Dict]:
        """Return the built-in shape templates (vertices in a 0-100 box)."""
        return [
            {'name': 'triangle', 'positions': [(0, 0), (100, 0), (50, 86.6)]},
            {'name': 'square', 'positions': [(0, 0), (100, 0), (100, 100), (0, 100)]},
            {'name': 'star', 'positions': [(50, 0), (61, 35), (97, 35), (68, 57), (79, 91), (50, 70), (21, 91), (32, 57), (3, 35), (39, 35)]},
            {'name': 'hexagon', 'positions': [(50, 0), (93, 25), (93, 75), (50, 100), (7, 75), (7, 25)]},
        ]

    def _detect_repetitive_patterns(self, beatmap_data: Dict) -> List[Tuple[int, int]]:
        """Find index ranges where one 8-object window repeats immediately after itself."""
        repetitive_sections = []
        objects = beatmap_data.get('hit_objects', [])

        window_size = 8
        for i in range(len(objects) - window_size * 2):
            pattern1 = self._extract_pattern(objects[i:i + window_size])
            pattern2 = self._extract_pattern(objects[i + window_size:i + window_size * 2])
            if self._patterns_similar(pattern1, pattern2):
                repetitive_sections.append((i, i + window_size * 2))

        return repetitive_sections

    def _extract_pattern(self, objects: List[Dict]) -> List[Tuple[float, float]]:
        """Positions of the given objects (playfield center as fallback)."""
        return [obj.get('position', (256, 192)) for obj in objects]

    def _patterns_similar(self, pattern1: List, pattern2: List, threshold: float = 0.8) -> bool:
        """True when two equal-length patterns are positionally close.

        NOTE(review): the `threshold` parameter is currently unused — similarity
        is decided by a hard-coded 50px average point-to-point distance. Kept
        for interface compatibility; confirm intended semantics before wiring
        it in.
        """
        if len(pattern1) != len(pattern2):
            return False

        deltas = [
            np.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
            for p1, p2 in zip(pattern1, pattern2)
        ]
        return np.mean(deltas) < 50  # Threshold for similarity

    def _vary_pattern(self, beatmap_data: Dict, section: Tuple[int, int]) -> Dict:
        """Overwrite the objects in `section` with a randomly chosen library shape."""
        start, end = section
        objects = beatmap_data['hit_objects']

        # Select a random pattern from library and tile it across the section.
        # (Removed unused local `section_length` from the original.)
        pattern = np.random.choice(self.pattern_library)
        pattern_positions = pattern['positions']

        for i in range(start, min(end, len(objects))):
            if 'position' in objects[i]:
                pattern_idx = (i - start) % len(pattern_positions)
                base_pos = pattern_positions[pattern_idx]

                # Scale the 0-100 template up and anchor it at playfield center.
                center = (256, 192)
                scale = 2.0
                new_x = center[0] + base_pos[0] * scale
                new_y = center[1] + base_pos[1] * scale

                # Clamp to the osu! playfield (512x384).
                objects[i]['position'] = (
                    max(0, min(512, new_x)),
                    max(0, min(384, new_y)),
                )

        return beatmap_data
358
+
359
+
360
class QualityController:
    """Enforces minimum spacing and resolves overlapping hit objects."""

    def __init__(self, config: BeatHeritageConfig):
        self.config = config

    def fix_spacing_issues(self, beatmap_data: Dict) -> Dict:
        """Push apart consecutive objects closer than the configured minimum.

        Mutates and returns `beatmap_data`.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        objects = beatmap_data['hit_objects']
        min_distance = self.config.min_distance_threshold

        for idx in range(1, len(objects)):
            prev_obj = objects[idx - 1]
            cur_obj = objects[idx]
            if 'position' not in prev_obj or 'position' not in cur_obj:
                continue

            gap = self._calculate_distance(prev_obj['position'], cur_obj['position'])
            if gap < min_distance:
                # Slide the current object outward along the existing direction
                # so it sits exactly at the minimum distance.
                heading = self._get_direction(prev_obj['position'], cur_obj['position'])
                cur_obj['position'] = self._move_position(
                    prev_obj['position'], heading, min_distance
                )

        return beatmap_data

    def fix_overlaps(self, beatmap_data: Dict) -> Dict:
        """Separate objects that overlap beyond the configured ratio.

        Mutates and returns `beatmap_data`.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        objects = beatmap_data['hit_objects']
        max_overlap = self.config.max_overlap_ratio

        for i in range(len(objects)):
            # Only the next few objects can plausibly overlap on screen.
            for j in range(i + 1, min(i + 10, len(objects))):
                if self._objects_overlap(objects[i], objects[j], max_overlap):
                    objects[j] = self._adjust_for_overlap(objects[i], objects[j])

        return beatmap_data

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Euclidean distance between two playfield positions."""
        dx = pos1[0] - pos2[0]
        dy = pos1[1] - pos2[1]
        return np.sqrt(dx**2 + dy**2)

    def _get_direction(self, from_pos: Tuple[float, float],
                       to_pos: Tuple[float, float]) -> Tuple[float, float]:
        """Unit vector pointing from `from_pos` toward `to_pos`."""
        dx = to_pos[0] - from_pos[0]
        dy = to_pos[1] - from_pos[1]
        length = np.sqrt(dx**2 + dy**2)
        # Degenerate (near-coincident) points: fall back to pointing right.
        if length < 0.01:
            return (1, 0)
        return (dx / length, dy / length)

    def _move_position(self, from_pos: Tuple[float, float],
                       direction: Tuple[float, float],
                       distance: float) -> Tuple[float, float]:
        """Point reached by walking `distance` along `direction`, clamped to the playfield."""
        x = from_pos[0] + direction[0] * distance
        y = from_pos[1] + direction[1] * distance
        return (
            max(0, min(512, x)),
            max(0, min(384, y)),
        )

    def _objects_overlap(self, obj1: Dict, obj2: Dict, threshold: float) -> bool:
        """Whether two circles overlap by more than `threshold` of their combined width."""
        if 'position' not in obj1 or 'position' not in obj2:
            return False

        separation = self._calculate_distance(obj1['position'], obj2['position'])

        # Simple overlap check (can be improved for sliders)
        radius = 30  # Approximate circle radius
        overlap_ratio = max(0, 2 * radius - separation) / (2 * radius)
        return overlap_ratio > threshold

    def _adjust_for_overlap(self, obj1: Dict, obj2: Dict) -> Dict:
        """Push obj2 away from obj1 to a safe distance; returns obj2."""
        if 'position' not in obj1 or 'position' not in obj2:
            return obj2

        away = self._get_direction(obj1['position'], obj2['position'])
        min_safe_distance = 60  # Minimum safe distance
        obj2['position'] = self._move_position(obj1['position'], away, min_safe_distance)
        return obj2
471
+
472
+
473
+ # Export main postprocessor
474
+ __all__ = ['BeatHeritagePostprocessor', 'BeatHeritageConfig']
benchmark_comparison.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ BeatHeritage V1 vs Mapperatorinator V30 Benchmark Script
4
+ Compares performance, quality, and generation characteristics
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import time
10
+ import json
11
+ import argparse
12
+ import subprocess
13
+ import numpy as np
14
+ import pandas as pd
15
+ from pathlib import Path
16
+ from typing import Dict, List, Tuple, Optional
17
+ from datetime import datetime
18
+ import torch
19
+ import matplotlib.pyplot as plt
20
+ import seaborn as sns
21
+ from tqdm import tqdm
22
+ import logging
23
+
24
# Setup logging
# Configured once at import time so every benchmark message shares the same
# timestamped format on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)  # module-level logger used by BenchmarkRunner
30
+
31
+
32
+ class BenchmarkRunner:
33
+ """Run benchmarks comparing BeatHeritage V1 with Mapperatorinator V30"""
34
+
35
+ def __init__(self, output_dir: str = "./benchmark_results"):
36
+ self.output_dir = Path(output_dir)
37
+ self.output_dir.mkdir(parents=True, exist_ok=True)
38
+ self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
39
+ self.results = []
40
+
41
+ def run_inference(self, model_config: str, audio_path: str,
42
+ gamemode: int, difficulty: float) -> Dict:
43
+ """Run inference with specified model and parameters"""
44
+
45
+ output_path = self.output_dir / f"{model_config}_{Path(audio_path).stem}"
46
+ output_path.mkdir(parents=True, exist_ok=True)
47
+
48
+ cmd = [
49
+ 'python', 'inference.py',
50
+ '-cn', model_config,
51
+ f'audio_path={audio_path}',
52
+ f'output_path={str(output_path)}',
53
+ f'gamemode={gamemode}',
54
+ f'difficulty={difficulty}',
55
+ ]
56
+
57
+ # Add model-specific parameters
58
+ if model_config == 'beatheritage_v1':
59
+ cmd.extend([
60
+ 'temperature=0.85',
61
+ 'top_p=0.92',
62
+ 'quality_control.enable_auto_correction=true',
63
+ 'quality_control.enable_flow_optimization=true',
64
+ 'advanced_features.enable_pattern_variety=true',
65
+ ])
66
+ else: # v30
67
+ cmd.extend([
68
+ 'temperature=0.9',
69
+ 'top_p=0.9',
70
+ ])
71
+
72
+ # Measure performance
73
+ start_time = time.time()
74
+ memory_before = self._get_memory_usage()
75
+
76
+ try:
77
+ result = subprocess.run(
78
+ cmd,
79
+ capture_output=True,
80
+ text=True,
81
+ check=True
82
+ )
83
+
84
+ end_time = time.time()
85
+ memory_after = self._get_memory_usage()
86
+
87
+ # Parse output for quality metrics
88
+ output_files = list(output_path.glob('*.osu'))
89
+
90
+ metrics = {
91
+ 'model': model_config,
92
+ 'audio': Path(audio_path).name,
93
+ 'gamemode': gamemode,
94
+ 'difficulty': difficulty,
95
+ 'generation_time': end_time - start_time,
96
+ 'memory_usage': memory_after - memory_before,
97
+ 'success': True,
98
+ 'output_files': len(output_files),
99
+ 'quality_metrics': self._analyze_quality(output_files[0] if output_files else None)
100
+ }
101
+
102
+ except subprocess.CalledProcessError as e:
103
+ logger.error(f"Error running {model_config}: {e}")
104
+ metrics = {
105
+ 'model': model_config,
106
+ 'audio': Path(audio_path).name,
107
+ 'gamemode': gamemode,
108
+ 'difficulty': difficulty,
109
+ 'generation_time': -1,
110
+ 'memory_usage': -1,
111
+ 'success': False,
112
+ 'error': str(e),
113
+ 'output_files': 0,
114
+ 'quality_metrics': {}
115
+ }
116
+
117
+ return metrics
118
+
119
+ def _get_memory_usage(self) -> float:
120
+ """Get current GPU memory usage in MB"""
121
+ if torch.cuda.is_available():
122
+ return torch.cuda.memory_allocated() / 1024**2
123
+ return 0
124
+
125
+ def _analyze_quality(self, osu_file: Optional[Path]) -> Dict:
126
+ """Analyze quality metrics of generated beatmap"""
127
+ if not osu_file or not osu_file.exists():
128
+ return {}
129
+
130
+ metrics = {
131
+ 'object_count': 0,
132
+ 'avg_spacing': 0,
133
+ 'spacing_variance': 0,
134
+ 'pattern_diversity': 0,
135
+ 'flow_score': 0,
136
+ 'difficulty_consistency': 0
137
+ }
138
+
139
+ try:
140
+ with open(osu_file, 'r', encoding='utf-8') as f:
141
+ lines = f.readlines()
142
+
143
+ # Parse hit objects
144
+ hit_objects = []
145
+ in_hit_objects = False
146
+
147
+ for line in lines:
148
+ if '[HitObjects]' in line:
149
+ in_hit_objects = True
150
+ continue
151
+
152
+ if in_hit_objects and line.strip():
153
+ parts = line.strip().split(',')
154
+ if len(parts) >= 2:
155
+ try:
156
+ x, y = int(parts[0]), int(parts[1])
157
+ hit_objects.append((x, y))
158
+ except:
159
+ pass
160
+
161
+ metrics['object_count'] = len(hit_objects)
162
+
163
+ if len(hit_objects) > 1:
164
+ # Calculate spacing metrics
165
+ distances = []
166
+ for i in range(1, len(hit_objects)):
167
+ dist = np.sqrt(
168
+ (hit_objects[i][0] - hit_objects[i-1][0])**2 +
169
+ (hit_objects[i][1] - hit_objects[i-1][1])**2
170
+ )
171
+ distances.append(dist)
172
+
173
+ metrics['avg_spacing'] = np.mean(distances)
174
+ metrics['spacing_variance'] = np.var(distances)
175
+
176
+ # Pattern diversity (entropy of distance distribution)
177
+ hist, _ = np.histogram(distances, bins=10)
178
+ hist = hist / hist.sum()
179
+ entropy = -np.sum(hist * np.log(hist + 1e-10))
180
+ metrics['pattern_diversity'] = entropy
181
+
182
+ # Flow score (based on angle changes)
183
+ if len(hit_objects) > 2:
184
+ angles = []
185
+ for i in range(2, len(hit_objects)):
186
+ angle = self._calculate_angle(
187
+ hit_objects[i-2],
188
+ hit_objects[i-1],
189
+ hit_objects[i]
190
+ )
191
+ angles.append(angle)
192
+
193
+ # Lower angle variance = better flow
194
+ metrics['flow_score'] = 1.0 / (1.0 + np.var(angles) / 100)
195
+
196
+ # Difficulty consistency
197
+ chunk_size = max(10, len(distances) // 10)
198
+ chunk_variances = []
199
+ for i in range(0, len(distances), chunk_size):
200
+ chunk = distances[i:i+chunk_size]
201
+ if chunk:
202
+ chunk_variances.append(np.var(chunk))
203
+
204
+ if chunk_variances:
205
+ metrics['difficulty_consistency'] = 1.0 / (1.0 + np.var(chunk_variances))
206
+
207
+ except Exception as e:
208
+ logger.error(f"Error analyzing quality: {e}")
209
+
210
+ return metrics
211
+
212
+ def _calculate_angle(self, p1: Tuple, p2: Tuple, p3: Tuple) -> float:
213
+ """Calculate angle between three points"""
214
+ v1 = (p2[0] - p1[0], p2[1] - p1[1])
215
+ v2 = (p3[0] - p2[0], p3[1] - p2[1])
216
+
217
+ angle1 = np.arctan2(v1[1], v1[0])
218
+ angle2 = np.arctan2(v2[1], v2[0])
219
+
220
+ angle_diff = angle2 - angle1
221
+ # Normalize to [-pi, pi]
222
+ while angle_diff > np.pi:
223
+ angle_diff -= 2 * np.pi
224
+ while angle_diff < -np.pi:
225
+ angle_diff += 2 * np.pi
226
+
227
+ return abs(angle_diff)
228
+
229
+ def run_benchmark_suite(self, test_audio_files: List[str]):
230
+ """Run complete benchmark suite"""
231
+
232
+ models = ['beatheritage_v1', 'v30']
233
+ gamemodes = [0, 1, 2, 3] # All gamemodes
234
+ difficulties = [3.0, 5.5, 7.5] # Easy, Normal, Hard
235
+
236
+ total_tests = len(test_audio_files) * len(models) * len(gamemodes) * len(difficulties)
237
+
238
+ with tqdm(total=total_tests, desc="Running benchmarks") as pbar:
239
+ for audio_file in test_audio_files:
240
+ for gamemode in gamemodes:
241
+ for difficulty in difficulties:
242
+ for model in models:
243
+ logger.info(f"Testing {model} on {audio_file} "
244
+ f"(GM:{gamemode}, Diff:{difficulty})")
245
+
246
+ result = self.run_inference(
247
+ model, audio_file, gamemode, difficulty
248
+ )
249
+ self.results.append(result)
250
+ pbar.update(1)
251
+
252
+ # Save intermediate results
253
+ self._save_results()
254
+
255
def _save_results(self):
    """Persist self.results as JSON (raw) and CSV (analysis-friendly)."""
    stem = f"benchmark_results_{self.timestamp}"

    # Raw dump for later reloading
    json_path = self.output_dir / f"{stem}.json"
    with open(json_path, 'w') as f:
        json.dump(self.results, f, indent=2)

    # Tabular form for pandas/spreadsheet analysis
    csv_path = self.output_dir / f"{stem}.csv"
    pd.DataFrame(self.results).to_csv(csv_path, index=False)

    logger.info(f"Results saved to {json_path} and {csv_path}")
268
+
269
def generate_report(self):
    """Build the benchmark report: a 2x3 grid of comparison plots plus a
    plain-text summary, both written to self.output_dir."""
    if not self.results:
        logger.error("No results to generate report")
        return

    report_df = pd.DataFrame(self.results)
    ok_df = report_df[report_df['success'] == True]

    plt.figure(figsize=(20, 12))

    # 1. Generation time per model
    ax1 = plt.subplot(2, 3, 1)
    if not ok_df.empty:
        sns.boxplot(data=ok_df, x='model', y='generation_time', ax=ax1)
        ax1.set_title('Generation Time Comparison')
        ax1.set_ylabel('Time (seconds)')

    # 2. Memory usage per model
    ax2 = plt.subplot(2, 3, 2)
    if not ok_df.empty:
        sns.boxplot(data=ok_df, x='model', y='memory_usage', ax=ax2)
        ax2.set_title('Memory Usage Comparison')
        ax2.set_ylabel('Memory (MB)')

    # 3. Success rate per model
    ax3 = plt.subplot(2, 3, 3)
    (report_df.groupby('model')['success'].mean() * 100).plot(kind='bar', ax=ax3)
    ax3.set_title('Success Rate (%)')
    ax3.set_ylabel('Success Rate')
    ax3.set_ylim(0, 105)

    # 4-6. Quality metrics, flattened out of the per-run nested dicts
    if not ok_df.empty and 'quality_metrics' in ok_df.columns:
        flattened = [
            {
                'model': row['model'],
                'pattern_diversity': row['quality_metrics'].get('pattern_diversity', 0),
                'flow_score': row['quality_metrics'].get('flow_score', 0),
                'difficulty_consistency': row['quality_metrics'].get('difficulty_consistency', 0),
            }
            for _, row in ok_df.iterrows()
            if row['quality_metrics']
        ]

        if flattened:
            quality_df = pd.DataFrame(flattened)

            ax4 = plt.subplot(2, 3, 4)
            if 'pattern_diversity' in quality_df.columns:
                sns.boxplot(data=quality_df, x='model', y='pattern_diversity', ax=ax4)
                ax4.set_title('Pattern Diversity Score')

            ax5 = plt.subplot(2, 3, 5)
            if 'flow_score' in quality_df.columns:
                sns.boxplot(data=quality_df, x='model', y='flow_score', ax=ax5)
                ax5.set_title('Flow Quality Score')

            ax6 = plt.subplot(2, 3, 6)
            if 'difficulty_consistency' in quality_df.columns:
                sns.boxplot(data=quality_df, x='model', y='difficulty_consistency', ax=ax6)
                ax6.set_title('Difficulty Consistency Score')

    plt.suptitle('BeatHeritage V1 vs Mapperatorinator V30 Benchmark Report', fontsize=16)
    plt.tight_layout()

    # Persist the figure before showing it interactively
    plot_path = self.output_dir / f"benchmark_report_{self.timestamp}.png"
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()

    # Companion plain-text summary
    summary_path = self.output_dir / f"benchmark_summary_{self.timestamp}.txt"
    with open(summary_path, 'w') as f:
        f.write(self._generate_text_summary(report_df))

    logger.info(f"Report generated: {plot_path} and {summary_path}")
353
+
354
+ def _generate_text_summary(self, df: pd.DataFrame) -> str:
355
+ """Generate text summary of benchmark results"""
356
+
357
+ summary = []
358
+ summary.append("=" * 80)
359
+ summary.append("BEATHERITAGE V1 VS MAPPERATORINATOR V30 BENCHMARK SUMMARY")
360
+ summary.append("=" * 80)
361
+ summary.append(f"Timestamp: {self.timestamp}")
362
+ summary.append(f"Total Tests: {len(df)}")
363
+ summary.append("")
364
+
365
+ for model in df['model'].unique():
366
+ model_df = df[df['model'] == model]
367
+ successful_df = model_df[model_df['success'] == True]
368
+
369
+ summary.append(f"\n{model.upper()}")
370
+ summary.append("-" * 40)
371
+ summary.append(f"Success Rate: {model_df['success'].mean()*100:.1f}%")
372
+
373
+ if not successful_df.empty:
374
+ summary.append(f"Avg Generation Time: {successful_df['generation_time'].mean():.2f}s")
375
+ summary.append(f"Avg Memory Usage: {successful_df['memory_usage'].mean():.1f}MB")
376
+
377
+ # Quality metrics
378
+ quality_metrics = []
379
+ for _, row in successful_df.iterrows():
380
+ if row['quality_metrics']:
381
+ quality_metrics.append(row['quality_metrics'])
382
+
383
+ if quality_metrics:
384
+ avg_diversity = np.mean([m.get('pattern_diversity', 0) for m in quality_metrics])
385
+ avg_flow = np.mean([m.get('flow_score', 0) for m in quality_metrics])
386
+ avg_consistency = np.mean([m.get('difficulty_consistency', 0) for m in quality_metrics])
387
+
388
+ summary.append(f"Avg Pattern Diversity: {avg_diversity:.3f}")
389
+ summary.append(f"Avg Flow Score: {avg_flow:.3f}")
390
+ summary.append(f"Avg Difficulty Consistency: {avg_consistency:.3f}")
391
+
392
+ # Winner determination
393
+ summary.append("\n" + "=" * 80)
394
+ summary.append("WINNER ANALYSIS")
395
+ summary.append("=" * 80)
396
+
397
+ if len(df['model'].unique()) == 2:
398
+ model1, model2 = df['model'].unique()
399
+
400
+ # Compare metrics
401
+ metrics_comparison = []
402
+
403
+ for metric in ['generation_time', 'memory_usage']:
404
+ m1_avg = df[df['model'] == model1][metric].mean()
405
+ m2_avg = df[df['model'] == model2][metric].mean()
406
+
407
+ if m1_avg < m2_avg:
408
+ winner = model1
409
+ improvement = ((m2_avg - m1_avg) / m2_avg) * 100
410
+ else:
411
+ winner = model2
412
+ improvement = ((m1_avg - m2_avg) / m1_avg) * 100
413
+
414
+ metrics_comparison.append(
415
+ f"{metric}: {winner} ({improvement:.1f}% better)"
416
+ )
417
+
418
+ for comp in metrics_comparison:
419
+ summary.append(comp)
420
+
421
+ return "\n".join(summary)
422
+
423
+
424
def main():
    """CLI entry point: collect test audio, run the benchmark suite, report.

    Flags:
        --audio-dir   directory scanned for .mp3/.ogg test files
        --output-dir  destination for JSON/CSV/plot/summary artifacts
        --quick-test  restrict the run to a single audio file
    """
    parser = argparse.ArgumentParser(description='Benchmark BeatHeritage V1 vs V30')
    parser.add_argument(
        '--audio-dir',
        type=str,
        default='./test_audio',
        help='Directory containing test audio files'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='./benchmark_results',
        help='Directory to save benchmark results'
    )
    parser.add_argument(
        '--quick-test',
        action='store_true',
        help='Run quick test with limited parameters'
    )

    args = parser.parse_args()

    # Gather test audio; sorted() makes run order deterministic across platforms.
    audio_dir = Path(args.audio_dir)
    audio_files = []
    if audio_dir.exists():
        audio_files = sorted(audio_dir.glob('*.mp3')) + sorted(audio_dir.glob('*.ogg'))

    if not audio_files:
        # Fix: previously an existing-but-empty directory silently produced a
        # zero-test benchmark; now both missing and empty dirs fall back to demo.
        logger.warning(f"No audio files found in {audio_dir}, using demo files")
        audio_files = ['demo.mp3']  # Fallback to demo

    if args.quick_test:
        # Quick test with limited parameters
        audio_files = audio_files[:1]
        logger.info("Running quick test with 1 audio file")

    # Run benchmarks and produce the report
    runner = BenchmarkRunner(args.output_dir)
    runner.run_benchmark_suite([str(f) for f in audio_files])
    runner.generate_report()

    logger.info("Benchmark complete!")


if __name__ == "__main__":
    main()
calc_fid.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import traceback
5
+ from datetime import timedelta
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import hydra
10
+ import numpy as np
11
+ import torch
12
+ from scipy import linalg
13
+ from slider import Beatmap, Circle, Slider, Spinner, HoldNote
14
+ from torch.utils.data import DataLoader
15
+ from tqdm import tqdm
16
+
17
+ from classifier.classify import ExampleDataset
18
+ from classifier.libs.model.model import OsuClassifierOutput
19
+ from classifier.libs.utils import load_ckpt
20
+ from config import FidConfig
21
+ from inference import prepare_args, load_diff_model, generate, load_model
22
+ from osuT5.osuT5.dataset.data_utils import load_audio_file, load_mmrs_metadata, filter_mmrs_metadata
23
+ from osuT5.osuT5.inference import generation_config_from_beatmap, beatmap_config_from_beatmap
24
+ from osuT5.osuT5.tokenizer import ContextType
25
+ from multiprocessing import Manager, Process
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
def get_beatmap_paths(args: FidConfig) -> list[Path]:
    """Get all beatmap paths (.osu) from the dataset directory.

    Supports the "mmrs" layout (metadata-driven) and the "ors" layout
    (Track00000-style folders with a beatmaps/ subdirectory).
    """
    dataset_path = Path(args.dataset_path)

    if args.dataset_type == "mmrs":
        filtered = filter_mmrs_metadata(
            load_mmrs_metadata(dataset_path),
            start=args.dataset_start,
            end=args.dataset_end,
            gamemodes=args.gamemodes,
        )
        return [
            dataset_path / "data" / row["BeatmapSetFolder"] / row["BeatmapFile"]
            for _, row in filtered.iterrows()
        ]

    if args.dataset_type == "ors":
        paths = []
        for track_idx in range(args.dataset_start, args.dataset_end):
            track_dir = dataset_path / f"Track{str(track_idx).zfill(5)}" / "beatmaps"
            for entry in track_dir.iterdir():
                paths.append(track_dir / entry.name)
        return paths

    raise ValueError(f"Unknown dataset type: {args.dataset_type}")
53
+
54
+
55
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance between two Gaussians.

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2):
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1 : mean of layer activations for generated samples.
    -- mu2 : mean of activations precalculated on a representative data set.
    -- sigma1: covariance of activations for generated samples.
    -- sigma2: covariance of activations for the representative data set.
    -- eps: diagonal jitter added when the covariance product is singular.

    Returns:
    -- : The Frechet Distance.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    mean_diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        logger.warning(msg)
        jitter = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + jitter).dot(sigma2 + jitter))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    return mean_diff.dot(mean_diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
113
+
114
+
115
def add_to_dict(source_dict, target_dict):
    """Accumulate source_dict into target_dict in place, summing values on
    key collisions."""
    for key, value in source_dict.items():
        target_dict[key] = target_dict[key] + value if key in target_dict else value
121
+
122
+
123
def calculate_rhythm_stats(real_rhythm, generated_rhythm):
    """Compare two sets of beat timestamps (ms) within a 10 ms leniency.

    Returns a dict with:
      true_positives  -- real beats matched by some generated beat
      false_positives -- generated beats with no nearby real beat
      false_negatives -- real beats left unmatched
    """
    leniency = 10

    matched = sum(
        1 for real_beat in real_rhythm
        if any(abs(real_beat - gen_beat) <= leniency for gen_beat in generated_rhythm)
    )
    spurious = sum(
        1 for gen_beat in generated_rhythm
        if not any(abs(gen_beat - real_beat) <= leniency for real_beat in real_rhythm)
    )

    return {
        "true_positives": matched,
        "false_positives": spurious,
        "false_negatives": len(real_rhythm) - matched,
    }
145
+
146
+
147
def calculate_precision(rhythm_stats):
    """tp / (tp + fp); 0.0 when there are no positive predictions."""
    tp = rhythm_stats["true_positives"]
    predicted = tp + rhythm_stats["false_positives"]
    return tp / predicted if predicted else 0.0


def calculate_recall(rhythm_stats):
    """tp / (tp + fn); 0.0 when there are no actual positives."""
    tp = rhythm_stats["true_positives"]
    actual = tp + rhythm_stats["false_negatives"]
    return tp / actual if actual else 0.0


def calculate_f1(rhythm_stats):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    precision = calculate_precision(rhythm_stats)
    recall = calculate_recall(rhythm_stats)
    denominator = precision + recall
    return 2 * (precision * recall) / denominator if denominator else 0.0
169
+
170
+
171
def get_rhythm(beatmap, passive=False):
    """Extract the set of beat timestamps (integer milliseconds) from a beatmap.

    Active rhythm covers circles, slider heads, and hold note heads.
    With passive=True, slider repeats/tails and spinner tails are included too.
    """
    def to_ms(td):
        # timedelta -> integer ms, with a small epsilon against float truncation
        return int(td.total_seconds() * 1000 + 1e-5)

    rhythm = set()
    for obj in beatmap.hit_objects(stacking=False):
        if isinstance(obj, Circle):
            rhythm.add(to_ms(obj.time))
        elif isinstance(obj, Slider):
            span: timedelta = (obj.end_time - obj.time) / obj.repeat
            rhythm.add(to_ms(obj.time))
            if passive:
                # One timestamp per span end: repeats and the final tail
                for i in range(obj.repeat):
                    rhythm.add(to_ms(obj.time + span * (i + 1)))
        elif isinstance(obj, Spinner):
            if passive:
                rhythm.add(to_ms(obj.end_time))
        elif isinstance(obj, HoldNote):
            rhythm.add(to_ms(obj.time))

    return rhythm
192
+
193
+
194
def generate_beatmaps(beatmap_paths, fid_args: FidConfig, return_dict, idx):
    """Worker entry point: generate a counterpart for every beatmap path.

    Outputs are written under generated/<beatmap stem>/. `return_dict` is
    accepted for multiprocessing-interface compatibility but is not written
    to here. `idx` only labels the progress bar.
    """
    args = fid_args.inference
    args.device = fid_args.device
    torch.set_grad_enabled(False)
    torch.set_float32_matmul_precision('high')

    diff_model = diff_tokenizer = refine_model = None
    model, tokenizer = load_model(args.model_path, args.train, args.device, args.max_batch_size, args.use_server, args.precision)

    if args.compile:
        model.transformer.forward = torch.compile(model.transformer.forward, mode="reduce-overhead", fullgraph=True)

    if args.generate_positions:
        diff_model, diff_tokenizer = load_diff_model(args.diff_ckpt, args.diffusion, args.device)

        if os.path.exists(args.diff_refine_ckpt):
            refine_model = load_diff_model(args.diff_refine_ckpt, args.diffusion, args.device)[0]

        if args.compile:
            diff_model.forward = torch.compile(diff_model.forward, mode="reduce-overhead", fullgraph=False)

    for beatmap_path in tqdm(beatmap_paths, desc=f"Process {idx}"):
        try:
            beatmap = Beatmap.from_path(beatmap_path)
            output_path = Path("generated") / beatmap_path.stem

            # ors layout keeps one shared audio file two levels up
            if fid_args.dataset_type == "ors":
                audio_path = beatmap_path.parents[1] / list(beatmap_path.parents[1].glob('audio.*'))[0]
            else:
                audio_path = beatmap_path.parent / beatmap.audio_filename

            already_generated = output_path.exists() and len(list(output_path.glob("*.osu"))) > 0
            if fid_args.skip_generation or already_generated:
                if not already_generated:
                    raise FileNotFoundError(f"Generated beatmap not found in {output_path}")
                print(f"Skipping {beatmap_path.stem} as it already exists")
                continue

            # With guest-difficulty context, condition on a sibling difficulty
            if ContextType.GD in args.in_context:
                siblings = [p for p in beatmap_path.parent.glob("*.osu") if p != beatmap_path]
                if len(siblings) == 0:
                    continue
                reference_path = random.choice(siblings)
            else:
                reference_path = beatmap_path

            generation_config = generation_config_from_beatmap(beatmap, tokenizer)
            beatmap_config = beatmap_config_from_beatmap(beatmap)
            beatmap_config.version = args.version

            if args.year is not None:
                generation_config.year = args.year

            result = generate(
                args,
                audio_path=audio_path,
                beatmap_path=reference_path,
                output_path=output_path,
                generation_config=generation_config,
                beatmap_config=beatmap_config,
                model=model,
                tokenizer=tokenizer,
                diff_model=diff_model,
                diff_tokenizer=diff_tokenizer,
                refine_model=refine_model,
                verbose=False,
            )[0]
            generated_beatmap = Beatmap.parse(result)
            print(beatmap_path, "Generated %s hit objects" % len(generated_beatmap.hit_objects(stacking=False)))
        except Exception as e:
            print(f"Error processing {beatmap_path}: {e}")
            traceback.print_exc()
        finally:
            torch.cuda.empty_cache()  # Clear any cached memory
266
+
267
+
268
def calculate_metrics(args: FidConfig, beatmap_paths: list[Path]):
    """Score generated beatmaps against their real counterparts.

    For each real beatmap with an output in generated/<stem>/, optionally
    accumulates classifier feature vectors (for FID) and rhythm-matching
    stats, then logs the aggregate FID and precision/recall/F1 figures.
    """
    print("Calculating metrics...")

    classifier_model = classifier_args = classifier_tokenizer = None
    if args.fid:
        classifier_model, classifier_args, classifier_tokenizer = load_ckpt(args.classifier_ckpt)

        if args.compile:
            classifier_model.model.transformer.forward = torch.compile(classifier_model.model.transformer.forward,
                                                                       mode="reduce-overhead", fullgraph=False)

    real_features = []
    generated_features = []
    active_rhythm_stats = {}
    passive_rhythm_stats = {}

    for beatmap_path in tqdm(beatmap_paths, desc="Metrics"):
        try:
            beatmap = Beatmap.from_path(beatmap_path)
            generated_path = Path("generated") / beatmap_path.stem

            # ors layout keeps one shared audio file two levels up
            if args.dataset_type == "ors":
                audio_path = beatmap_path.parents[1] / list(beatmap_path.parents[1].glob('audio.*'))[0]
            else:
                audio_path = beatmap_path.parent / beatmap.audio_filename

            candidates = list(generated_path.glob("*.osu")) if generated_path.exists() else []
            if not candidates:
                logger.warning(f"Skipping {beatmap_path.stem} as no generated beatmap found")
                continue
            generated_beatmap = Beatmap.from_path(candidates[0])

            if args.fid:
                # Classifier feature vectors for real and generated beatmaps
                sample_rate = classifier_args.data.sample_rate
                audio = load_audio_file(audio_path, sample_rate, normalize=args.inference.train.data.normalize_audio)

                for batch in DataLoader(
                        ExampleDataset(beatmap, audio, classifier_args, classifier_tokenizer, args.device),
                        batch_size=args.classifier_batch_size):
                    output: OsuClassifierOutput = classifier_model(**batch)
                    real_features.append(output.feature_vector.cpu().numpy())

                for batch in DataLoader(
                        ExampleDataset(generated_beatmap, audio, classifier_args, classifier_tokenizer, args.device),
                        batch_size=args.classifier_batch_size):
                    output: OsuClassifierOutput = classifier_model(**batch)
                    generated_features.append(output.feature_vector.cpu().numpy())

            if args.rhythm_stats:
                # Accumulate beat-matching stats for active and passive rhythm
                add_to_dict(
                    calculate_rhythm_stats(get_rhythm(beatmap, passive=False),
                                           get_rhythm(generated_beatmap, passive=False)),
                    active_rhythm_stats)
                add_to_dict(
                    calculate_rhythm_stats(get_rhythm(beatmap, passive=True),
                                           get_rhythm(generated_beatmap, passive=True)),
                    passive_rhythm_stats)
        except Exception as e:
            print(f"Error processing {beatmap_path}: {e}")
            traceback.print_exc()
        finally:
            torch.cuda.empty_cache()  # Clear any cached memory

    if args.fid:
        # Fit a Gaussian to each feature cloud and compare
        real_features = np.concatenate(real_features, axis=0)
        generated_features = np.concatenate(generated_features, axis=0)
        m1, s1 = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
        m2, s2 = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
        fid = calculate_frechet_distance(m1, s1, m2, s2)

        logger.info(f"FID: {fid}")

    if args.rhythm_stats:
        active_precision = calculate_precision(active_rhythm_stats)
        active_recall = calculate_recall(active_rhythm_stats)
        active_f1 = calculate_f1(active_rhythm_stats)
        passive_precision = calculate_precision(passive_rhythm_stats)
        passive_recall = calculate_recall(passive_rhythm_stats)
        passive_f1 = calculate_f1(passive_rhythm_stats)
        logger.info(f"Active Rhythm Precision: {active_precision}")
        logger.info(f"Active Rhythm Recall: {active_recall}")
        logger.info(f"Active Rhythm F1: {active_f1}")
        logger.info(f"Passive Rhythm Precision: {passive_precision}")
        logger.info(f"Passive Rhythm Recall: {passive_recall}")
        logger.info(f"Passive Rhythm F1: {passive_f1}")
358
+
359
+
360
def test_training_set_overlap(beatmap_paths: list[Path], training_set_ids_path: Optional[str]):
    """Log how many of the given beatmaps appear in the training-set ID list.

    No-op when no ID file is configured; logs an error when the configured
    file is missing. The file format is one integer beatmap ID per line.
    """
    if training_set_ids_path is None:
        return

    if not os.path.exists(training_set_ids_path):
        logger.error(f"Training set IDs file {training_set_ids_path} does not exist.")
        return

    with open(training_set_ids_path, "r") as f:
        training_set_ids = {int(line.strip()) for line in f}

    in_set = sum(
        1 for path in tqdm(beatmap_paths)
        if Beatmap.from_path(path).beatmap_id in training_set_ids
    )
    out_set = len(beatmap_paths) - in_set
    logger.info(f"In training set: {in_set}, Not in training set: {out_set}, Total: {len(beatmap_paths)}, Ratio: {in_set / (in_set + out_set):.2f}")
380
+
381
+
382
@hydra.main(config_path="configs", config_name="calc_fid", version_base="1.1")
def main(args: FidConfig):
    """Hydra entry point: optionally generate beatmaps in parallel worker
    processes, then compute FID / rhythm metrics over the results."""
    prepare_args(args)

    # Resolve a relative model path against this file's directory
    if args.inference.model_path.startswith("./"):
        args.inference.model_path = os.path.join(Path(__file__).parent, args.inference.model_path[2:])

    beatmap_paths = get_beatmap_paths(args)

    test_training_set_overlap(beatmap_paths, args.training_set_ids_path)

    if not args.skip_generation:
        # Round-robin assignment of beatmaps to worker processes
        num_processes = args.num_processes
        chunks = [beatmap_paths[i::num_processes] for i in range(num_processes)]

        manager = Manager()
        return_dict = manager.dict()

        workers = [
            Process(target=generate_beatmaps, args=(chunk, args, return_dict, i))
            for i, chunk in enumerate(chunks)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

    calculate_metrics(args, beatmap_paths)


if __name__ == "__main__":
    main()
classifier/README.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mapper Classifier
2
+
3
+ Try the model [here](https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/classifier_classify.ipynb).
4
+
5
+ Mapper Classifier is a model that predicts which osu! standard ranked mapper mapped a given beatmap.
6
+
7
+ This model is built using transfer learning on the Mapperatorinator V22 model.
8
+ It achieves a top-1 validation accuracy of 12.5% on a random sample of ranked beatmaps and recognizes 3,731 unique mappers.
9
+ To make its predictions, the model analyzes an 8-second segment of the beatmap.
10
+
11
+ The purpose of this classifier is actually to calculate high-level feature vectors for beatmaps, which can be used to calculate the similarity between generated beatmaps and real beatmaps.
12
+ This is a technique often used to assess the quality of image generation models with the [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500).
13
+ However, in my testing I found that the computed FID scores for beatmap generation models were not very close to the actual quality of the generated beatmaps.
14
+ This classifier might not be able to recognize all the necessary features to accurately assess the quality of a beatmap, but it's a start.
15
+
16
+ ## Usage
17
+
18
+ Run `classify.py` with the path to the beatmap you want to classify and the time in seconds of the segment you want to use to classify the beatmap.
19
+ ```shell
20
+ python classify.py beatmap_path="'...\Songs\1790119 THE ORAL CIGARETTES - ReI\THE ORAL CIGARETTES - ReI (Sotarks) [Cataclysm.].osu'" time=60
21
+ ```
22
+
23
+ ```
24
+ Mapper: Sotarks (4452992) with confidence: 9.760356903076172
25
+ Mapper: Sajinn (13513687) with confidence: 6.975161075592041
26
+ Mapper: kowari (5404892) with confidence: 6.800069332122803
27
+ Mapper: Haruto (3772301) with confidence: 6.077754020690918
28
+ Mapper: Kalibe (3376777) with confidence: 5.894346237182617
29
+ Mapper: iljaaz (8501291) with confidence: 5.873990535736084
30
+ Mapper: tomadoi (5712451) with confidence: 5.817874431610107
31
+ Mapper: Nao Tomori (5364763) with confidence: 5.144880294799805
32
+ Mapper: Kujinn (3723568) with confidence: 5.082106590270996
33
+ ...
34
+ ```
classifier/classify.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ import numpy.typing as npt
4
+
5
+ import hydra
6
+ import torch
7
+ from omegaconf import DictConfig
8
+ from slider import Beatmap
9
+ from torch.utils.data import IterableDataset
10
+
11
+ from classifier.libs.dataset import OsuParser
12
+ from classifier.libs.dataset.data_utils import load_audio_file
13
+ from classifier.libs.dataset.ors_dataset import STEPS_PER_MILLISECOND
14
+ from classifier.libs.model.model import OsuClassifierOutput
15
+ from classifier.libs.tokenizer import Tokenizer, Event, EventType
16
+ from classifier.libs.utils import load_ckpt
17
+
18
+
19
def iterate_examples(
    beatmap: Beatmap,
    audio: npt.NDArray,
    model_args: DictConfig,
    tokenizer: Tokenizer,
    device: torch.device
):
    """Yield one model-ready example per non-overlapping audio window."""
    window = (model_args.data.src_seq_len - 1) * model_args.data.hop_length
    sample_rate = model_args.data.sample_rate

    # Parse the beatmap once; each window re-slices the shared event stream
    parser = OsuParser(model_args, tokenizer)
    events, event_times = parser.parse(beatmap)

    for offset in range(0, len(audio) - window, window):
        yield create_example(events, event_times, audio, offset / sample_rate, model_args, tokenizer, device)
37
+
38
+
39
class ExampleDataset(IterableDataset):
    """Streams classifier examples for one (beatmap, audio) pair.

    Thin IterableDataset wrapper around iterate_examples so a DataLoader
    can batch the generated examples.
    """

    def __init__(self, beatmap, audio, classifier_args, classifier_tokenizer, device):
        self.beatmap = beatmap
        self.audio = audio
        self.classifier_args = classifier_args
        self.classifier_tokenizer = classifier_tokenizer
        self.device = device

    def __iter__(self):
        yield from iterate_examples(
            self.beatmap,
            self.audio,
            self.classifier_args,
            self.classifier_tokenizer,
            self.device,
        )
55
+
56
+
57
def create_example(
    events: list[Event],
    event_times: list[float],
    audio: npt.NDArray,
    time: float,
    model_args: DictConfig,
    tokenizer: Tokenizer,
    device: torch.device,
    unsqueeze: bool = False,
):
    """Build a single classifier input dict for the window starting at `time` (s).

    Slices the matching audio window, selects the events inside it,
    re-bases their TIME_SHIFTs to the window start, and tokenizes them
    into a fixed-length padded sequence. `unsqueeze` adds a batch dim.
    """
    frame_seq_len = model_args.data.src_seq_len - 1
    frame_size = model_args.data.hop_length
    sample_rate = model_args.data.sample_rate
    samples_per_sequence = frame_seq_len * frame_size
    sequence_duration = samples_per_sequence / sample_rate

    # Audio window for this example
    window_start = int(time * sample_rate)
    frames = torch.from_numpy(audio[window_start:window_start + samples_per_sequence]).to(torch.float32).to(device)

    # Events whose timestamps (ms) fall inside [time, time + sequence_duration)
    window_events = [
        event for event, event_time in zip(events, event_times)
        if time <= event_time / 1000 < time + sequence_duration
    ]
    # Re-base TIME_SHIFT events to the window start, in tokenizer step units
    for i, event in enumerate(window_events):
        if event.type == EventType.TIME_SHIFT:
            window_events[i] = Event(EventType.TIME_SHIFT, int((event.value - time * 1000) * STEPS_PER_MILLISECOND))

    # Fixed-length token sequence, pad-filled past the last event
    tokens = torch.full((model_args.data.tgt_seq_len,), tokenizer.pad_id, dtype=torch.long)
    for i in range(min(len(window_events), model_args.data.tgt_seq_len)):
        tokens[i] = tokenizer.encode(window_events[i])
    tokens = tokens.to(device)

    if unsqueeze:
        tokens = tokens.unsqueeze(0)
        frames = frames.unsqueeze(0)

    return {
        "decoder_input_ids": tokens,
        "decoder_attention_mask": tokens != tokenizer.pad_id,
        "frames": frames,
    }
101
+
102
+
103
def create_example_from_path(
    beatmap_path: str,
    audio_path: str,
    time: float,
    model_args: DictConfig,
    tokenizer: Tokenizer,
    device: torch.device,
    unsqueeze: bool = False,
):
    """Load a beatmap (and its audio) and build one classifier example at
    `time` seconds. An empty audio_path falls back to the audio file the
    beatmap itself references."""
    beatmap_path = Path(beatmap_path)
    beatmap = Beatmap.from_path(beatmap_path)

    # Default to the audio file referenced by the beatmap
    if audio_path == '':
        audio_path = beatmap_path.parent / beatmap.audio_filename

    audio = load_audio_file(audio_path, model_args.data.sample_rate)

    parser = OsuParser(model_args, tokenizer)
    events, event_times = parser.parse(beatmap)

    return create_example(events, event_times, audio, time, model_args, tokenizer, device, unsqueeze)
127
+
128
+
129
def get_mapper_names(path: str):
    """Load mapper display names keyed by user id from a JSON dump.

    Each record's first listed username is used; records with an empty
    username list map to "Unknown". Later duplicate user_ids win.
    """
    with open(Path(path), 'r') as file:
        records = json.load(file)

    return {
        item['user_id']: (item['username'][0] if len(item['username']) != 0 else "Unknown")
        for item in records
    }
146
+
147
+
148
@hydra.main(config_path="configs", config_name="inference", version_base="1.1")
def main(args: DictConfig):
    """Classify one beatmap segment and print the top-100 mapper guesses."""
    torch.set_grad_enabled(False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, model_args, tokenizer = load_ckpt(args.checkpoint_path)
    model.eval().to(device)

    example = create_example_from_path(args.beatmap_path, args.audio_path, args.time, model_args, tokenizer, device, True)
    result: OsuClassifierOutput = model(**example)
    logits = result.logits

    # Rank mappers by raw logit and report the top 100
    top_k = 100
    top = logits[0].topk(top_k)

    # Invert the tokenizer's mapper-id -> index table
    mapper_idx_id = {idx: ids for ids, idx in tokenizer.mapper_idx.items()}
    mapper_names = get_mapper_names(args.mappers_path)

    for idx, confidence in zip(top.indices, top.values):
        mapper_id = mapper_idx_id[idx.item()]
        mapper_name = mapper_names.get(mapper_id, "Unknown")
        print(f"Mapper: {mapper_name} ({mapper_id}) with confidence: {confidence.item()}")


if __name__ == "__main__":
    main()
classifier/configs/inference.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compile: true # PyTorch 2.0 optimization
2
+ device: gpu # Training device (cpu/gpu)
3
+ precision: 'no' # Enable mixed precision (no/fp16/bf16/fp8)
4
+ checkpoint_path: 'OliBomby/osu-classifier' # Trained classifier checkpoint (local path or HF repo) to load for inference
5
+ beatmap_path: '' # Path to beatmap to classify
6
+ audio_path: '' # Path to audio to classify
7
+ time: 0 # Time to classify
8
+ mappers_path: './datasets/beatmap_users.json' # Path to mappers dataset
9
+
10
+ hydra:
11
+ job:
12
+ chdir: False
13
+ run:
14
+ dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
classifier/configs/model/model.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ input_features: false
2
+ do_style_embed: true
3
+ classifier_proj_size: 256
4
+
5
+ spectrogram:
6
+ sample_rate: 16000
7
+ hop_length: 128
8
+ n_fft: 1024
9
+ n_mels: 388
classifier/configs/model/whisper_base.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'openai/whisper-base'
6
+ input_features: true
classifier/configs/model/whisper_base_v2.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'openai/whisper-base'
6
+ input_features: true
7
+ classifier_proj_size: 2048
classifier/configs/model/whisper_small.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'openai/whisper-small'
6
+ input_features: true
classifier/configs/model/whisper_tiny.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - model
3
+ - _self_
4
+
5
+ name: 'openai/whisper-tiny'
6
+ input_features: true
classifier/configs/train.yaml ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compile: true # PyTorch 2.0 optimization
2
+ device: gpu # Training device (cpu/gpu)
3
+ precision: 'bf16-mixed' # Enable mixed precision (no/fp16/bf16/fp8)
4
+ seed: 42 # Project seed
5
+
6
+ checkpoint_path: '' # Project checkpoint directory (to resume training)
7
+ pretrained_path: '' # Path to pretrained model weights (to do transfer learning)
8
+
9
+ data: # Data settings
10
+ train_dataset_path: "/workspace/datasets/ORS16291"
11
+ test_dataset_path: "/workspace/datasets/ORS16291"
12
+ train_dataset_start: 0 # Training dataset start index
13
+ train_dataset_end: 16200 # Training dataset end index
14
+ test_dataset_start: 16200 # Testing/validation dataset start index
15
+ test_dataset_end: 16291 # Testing/validation dataset end index
16
+ src_seq_len: 1024
17
+ tgt_seq_len: 1024
18
+ sample_rate: ${..model.spectrogram.sample_rate}
19
+ hop_length: ${..model.spectrogram.hop_length}
20
+ cycle_length: 16
21
+ per_track: false # Loads all beatmaps in a track sequentially which optimizes audio data loading
22
+ num_classes: 3731 # Number of label classes in the dataset
23
+ timing_random_offset: 0
24
+ min_difficulty: 0 # Minimum difficulty to consider including in the dataset
25
+ mappers_path: "../../../datasets/beatmap_users.json" # Path to file with all beatmap mappers
26
+ add_timing: true # Model beatmap timing
27
+ add_snapping: true # Model hit object snapping
28
+ add_timing_points: false # Model beatmap timing with timing points
29
+ add_hitsounds: true # Model beatmap hitsounds
30
+ add_distances: false # Model hit object distances
31
+ add_positions: true # Model hit object coordinates
32
+ position_precision: 1 # Precision of hit object coordinates
33
+ position_split_axes: true # Split hit object X and Y coordinates into separate tokens
34
+ position_range: [-256, 768, -256, 640] # Range of hit object coordinates
35
+ dt_augment_prob: 0.7 # Probability of augmenting the dataset with DT
36
+ dt_augment_range: [1.25, 1.5] # Range of DT augmentation
37
+ types_first: true # Put the type token at the start of the group before the timeshift token
38
+ augment_flip: false # Augment the dataset with flipped positions
39
+
40
+
41
+ dataloader: # Dataloader settings
42
+ num_workers: 8
43
+
44
+ optim: # Optimizer settings
45
+ name: adamw
46
+ base_lr: 1e-2 # Should be scaled with the number of devices present
47
+ batch_size: 128 # This is the batch size per GPU
48
+ total_steps: 65536
49
+ warmup_steps: 10000
50
+ lr_scheduler: cosine
51
+ weight_decay: 0.0
52
+ grad_clip: 1.0
53
+ grad_acc: 2
54
+ final_cosine: 1e-5
55
+
56
+ eval: # Evaluation settings
57
+ every_steps: 1000
58
+ steps: 500
59
+
60
+ checkpoint: # Checkpoint settings
61
+ every_steps: 5000
62
+
63
+ logging: # Logging settings
64
+ log_with: 'wandb' # Logging service (wandb/tensorboard)
65
+ every_steps: 10
66
+ grad_l2: true
67
+ weights_l2: true
68
+ mode: 'online'
69
+
70
+ profile: # Profiling settings
71
+ do_profile: false
72
+ early_stop: false
73
+ wait: 8
74
+ warmup: 8
75
+ active: 8
76
+ repeat: 1
77
+
78
+ hydra:
79
+ job:
80
+ chdir: True
81
+ run:
82
+ dir: ./logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
classifier/configs/train_v1.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ defaults:
2
+ - train
3
+ - _self_
4
+ - model: whisper_tiny
classifier/configs/train_v2.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train
3
+ - _self_
4
+ - model: whisper_base
5
+
6
+ pretrained_path: "../../../test/ckpt_v22"
7
+
8
+ optim: # Optimizer settings
9
+ base_lr: 1e-4 # Should be scaled with the number of devices present
10
+ batch_size: 64 # This is the batch size per GPU
11
+ total_steps: 32218
12
+ warmup_steps: 2000
13
+ grad_acc: 2
14
+ final_cosine: 1e-5
classifier/configs/train_v3.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - train
3
+ - _self_
4
+ - model: whisper_base_v2
5
+
6
+ pretrained_path: "../../../test/ckpt_v22"
7
+
8
+ data:
9
+ augment_flip: true
10
+
11
+ optim: # Optimizer settings
12
+ base_lr: 1e-3 # Should be scaled with the number of devices present
13
+ batch_size: 128 # This is the batch size per GPU
14
+ total_steps: 65536
15
+ warmup_steps: 2000
16
+ grad_acc: 4
17
+ final_cosine: 1e-5
classifier/count_classes.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+
5
def init_mapper_idx(mappers_path):
    """Index beatmap mappers and build a dense mapper-to-class-index mapping.

    Args:
        mappers_path: Path to a JSON file with a list of objects holding
            'id' (beatmap id) and 'user_id' (mapper id) keys.

    Returns:
        beatmap_mapper: Dict mapping beatmap id -> mapper user_id.
        mapper_idx: Dict mapping user_id -> dense class index.
        num_mapper_classes: Number of distinct mappers.

    Raises:
        ValueError: If the file does not exist.
    """
    # Note: the original docstring opened with four quotes, leaving a stray
    # quote character in the rendered docs — fixed here.
    path = Path(mappers_path)

    if not path.exists():
        raise ValueError(f"mappers_path {path} not found")

    # Load JSON data from file; explicit UTF-8 for cross-platform parsing.
    with open(path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Populate beatmap_mapper
    beatmap_mapper = {item['id']: item['user_id'] for item in data}

    # Get unique user_ids from beatmap_mapper values
    unique_user_ids = list(set(beatmap_mapper.values()))

    # Create mapper_idx: one dense class index per distinct mapper
    mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
    num_mapper_classes = len(unique_user_ids)

    return beatmap_mapper, mapper_idx, num_mapper_classes
29
+
30
+
31
# Standalone stats script: summarizes how many beatmaps each mapper has.
path = "../datasets/beatmap_users.json"
beatmap_mapper, mapper_idx, num_mapper_classes = init_mapper_idx(path)

print("Number of mapper classes:", num_mapper_classes)
print("Number of beatmaps:", len(beatmap_mapper))
# Calculate number of maps per mapper
maps_per_mapper = {}
for beatmap_id in beatmap_mapper:
    user_id = beatmap_mapper[beatmap_id]
    if user_id not in maps_per_mapper:
        maps_per_mapper[user_id] = 0
    maps_per_mapper[user_id] += 1

# Calculate average maps per mapper class
average_maps_per_mapper = len(beatmap_mapper) / num_mapper_classes
print("Average maps per mapper class:", average_maps_per_mapper)

# Calculate median maps per mapper class
# NOTE(review): this takes the element at index num_mapper_classes // 2 of
# the sorted counts (the upper median); even-sized lists are not averaged.
median_maps_per_mapper = sorted(maps_per_mapper.values())[num_mapper_classes // 2]
print("Median maps per mapper class:", median_maps_per_mapper)

# Mapper with most number of maps (all mappers tied at the maximum count)
max_maps = max(maps_per_mapper.values())
max_maps_mapper = [user_id for user_id in maps_per_mapper if maps_per_mapper[user_id] == max_maps]
print("Mapper with most number of maps:", max_maps_mapper)
print("Number of maps:", max_maps)
classifier/libs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils.model_utils import get_dataloaders, get_optimizer, get_scheduler, get_tokenizer
classifier/libs/dataset/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .ors_dataset import OrsDataset
2
+ from .osu_parser import OsuParser
3
+ from .data_utils import update_event_times
classifier/libs/dataset/data_utils.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from pydub import AudioSegment
7
+
8
+ import numpy.typing as npt
9
+
10
+ from ..tokenizer import Event, EventType
11
+
12
+ MILISECONDS_PER_SECOND = 1000
13
+
14
+
15
def load_audio_file(file: Path, sample_rate: int, speed: float = 1.0) -> npt.NDArray:
    """Load an audio file as a numpy time-series array

    The signals are resampled, converted to mono channel, and normalized.

    Args:
        file: Path to audio file.
        sample_rate: Sample rate to resample the audio.
        speed: Speed multiplier for the audio.

    Returns:
        samples: Audio time series (float32, peak-normalized to [-1, 1]).
    """
    file = Path(file)
    audio = AudioSegment.from_file(file, format=file.suffix[1:])
    # Reinterpreting the same samples at a scaled frame rate speeds up (or
    # slows down) playback; the following set_frame_rate resamples back to
    # the requested rate.
    audio.frame_rate = int(audio.frame_rate * speed)
    audio = audio.set_frame_rate(sample_rate)
    audio = audio.set_channels(1)
    samples = np.array(audio.get_array_of_samples()).astype(np.float32)
    # Peak-normalize, guarding against empty or completely silent audio
    # which previously divided by zero and produced NaNs/inf.
    peak = np.max(np.abs(samples)) if samples.size > 0 else 0.0
    if peak > 0:
        samples *= 1.0 / peak
    return samples
36
+
37
+
38
def update_event_times(
        events: list[Event],
        event_times: list[int],
        end_time: Optional[float] = None,
        types_first: bool = False
) -> None:
    """Extends the event times list with the times of the new events if the event list is longer than the event times list.

    Runs in two passes: first every new event inherits the time of its
    group's TIME_SHIFT token; then slider anchor events — which carry no
    time of their own — get times linearly interpolated between the
    surrounding timed events. Mutates `event_times` in place.

    Args:
        events: List of events.
        event_times: List of event times.
        end_time: End time of the events, for interpolation.
        types_first: If True, the type token is at the start of the group before the timeshift token.
    """
    # Anchor events have no TIME_SHIFT of their own; their times are interpolated.
    non_timed_events = [
        EventType.BEZIER_ANCHOR,
        EventType.PERFECT_ANCHOR,
        EventType.CATMULL_ANCHOR,
        EventType.RED_ANCHOR,
    ]
    # Events always paired with an explicit TIME_SHIFT token.
    timed_events = [
        EventType.CIRCLE,
        EventType.SPINNER,
        EventType.SPINNER_END,
        EventType.SLIDER_HEAD,
        EventType.LAST_ANCHOR,
        EventType.SLIDER_END,
        EventType.BEAT,
        EventType.MEASURE,
    ]

    # Pass 1: assign each new event the value of the TIME_SHIFT token of its
    # group (following the event when types_first, otherwise the current one),
    # carrying the last seen time forward in between.
    start_index = len(event_times)
    end_index = len(events)
    current_time = 0 if len(event_times) == 0 else event_times[-1]
    for i in range(start_index, end_index):
        if types_first:
            if i + 1 < end_index and events[i + 1].type == EventType.TIME_SHIFT:
                current_time = events[i + 1].value
        elif events[i].type == EventType.TIME_SHIFT:
            current_time = events[i].value
        event_times.append(current_time)

    # Pass 2: interpolate time for control point events.
    # Scan direction depends on token order: forward when types_first,
    # backward otherwise, so anchors are visited between their bounding
    # timed events. The diagrams show tokens / pass-1 times / final times.
    interpolate = False
    if types_first:
        # Start-T-D-CP-D-CP-D-LCP-T-D-End-T-D
        # 1-----1-1-1--1-1--1-7---7-7-9---9-9
        # 1-----1-1-3--3-5--5-7---7-7-9---9-9
        index = range(start_index, end_index)
        current_time = 0 if len(event_times) == 0 else event_times[-1]
    else:
        # T-D-Start-D-CP-D-CP-T-D-LCP-T-D-End
        # 1-1-1-----1-1--1-1--7-7--7--9-9-9--
        # 1-1-1-----3-3--5-5--7-7--7--9-9-9--
        index = range(end_index - 1, start_index - 1, -1)
        current_time = end_time if end_time is not None else event_times[-1]
    for i in index:
        event = events[i]

        if event.type in timed_events:
            interpolate = False

        if event.type in non_timed_events:
            interpolate = True

        if not interpolate:
            current_time = event_times[i]
            continue

        if event.type not in non_timed_events:
            # Auxiliary tokens inside an anchor group (distance, position,
            # ...) inherit the interpolated anchor time.
            event_times[i] = current_time
            continue

        # Find the time of the first timed event and the number of control points between
        j = i
        step = 1 if types_first else -1
        count = 0
        other_time = current_time
        while 0 <= j < len(events):
            event2 = events[j]
            if event2.type == EventType.TIME_SHIFT:
                other_time = event_times[j]
                break
            if event2.type in non_timed_events:
                count += 1
            j += step
        # Ran off either end without finding a TIME_SHIFT: fall back to the
        # sequence boundaries.
        if j < 0:
            other_time = 0
        if j >= len(events):
            other_time = end_time if end_time is not None else event_times[-1]

        # Interpolate the time
        current_time = int((current_time - other_time) / (count + 1) * count + other_time)
        event_times[i] = current_time
132
+
133
+
134
def merge_events(events1: list[Event], event_times1: list[int], events2: list[Event], event_times2: list[int]) -> tuple[list[Event], list[int]]:
    """Merge two time-sorted event streams into one, preserving time order.

    Ties are resolved in favour of the first stream, which keeps the merge
    stable. Assumes both lists are sorted by time.

    Args:
        events1: First list of events.
        event_times1: Times of the first list.
        events2: Second list of events.
        event_times2: Times of the second list.

    Returns:
        merged_events: Merged list of events.
        merged_event_times: Merged list of event times.
    """
    out_events = []
    out_times = []
    a = 0
    b = 0

    # Classic two-pointer merge over the parallel (event, time) lists.
    while a < len(events1) and b < len(events2):
        if event_times1[a] <= event_times2[b]:
            out_events.append(events1[a])
            out_times.append(event_times1[a])
            a += 1
        else:
            out_events.append(events2[b])
            out_times.append(event_times2[b])
            b += 1

    # At most one stream still has a tail; append both remainders.
    out_events += events1[a:]
    out_events += events2[b:]
    out_times += event_times1[a:]
    out_times += event_times2[b:]
    return out_events, out_times
170
+
171
+
172
def remove_events_of_type(events: list[Event], event_times: list[int], event_types: list[EventType]) -> tuple[list[Event], list[int]]:
    """Filter out all events whose type is in `event_types`.

    Args:
        events: List of events.
        event_times: Event times parallel to `events`.
        event_types: Types of event to remove.

    Returns:
        The filtered events and their times, as two parallel lists.
    """
    # Keep (event, time) pairs whose type survives the filter, then unzip.
    kept = [
        (event, time)
        for event, time in zip(events, event_times)
        if event.type not in event_types
    ]
    filtered_events = [event for event, _ in kept]
    filtered_times = [time for _, time in kept]
    return filtered_events, filtered_times
189
+
190
+
191
def speed_events(events: list[Event], event_times: list[int], speed: float) -> tuple[list[Event], list[int]]:
    """Change the speed of a list of events.

    Args:
        events: List of events.
        event_times: List of event times.
        speed: Speed multiplier.

    Returns:
        sped_events: Sped up list of events.
        sped_event_times: Rescaled list of event times.
    """
    # NOTE(review): TIME_SHIFT events are rescaled by mutating the Event
    # objects in place, so any caller sharing these objects observes the
    # rescaled values too — confirm this aliasing is intended.
    for event in events:
        if event.type == EventType.TIME_SHIFT:
            event.value = int(event.value / speed)

    sped_events = list(events)
    sped_event_times = [int(time / speed) for time in event_times]
    return sped_events, sped_event_times
213
+
214
+
215
@dataclasses.dataclass
class Group:
    """One decoded hit-object group collected from a flat event stream.

    Accumulates the typed event (circle, slider part, spinner, beat, ...)
    together with the auxiliary attributes that precede or follow it.
    """
    event_type: Optional[EventType] = None  # set once the group's typed event is seen
    time: int = 0  # group time from the TIME_SHIFT token (presumably milliseconds — confirm)
    distance: Optional[int] = None  # spacing value, when distances are modeled
    x: Optional[float] = None  # hit object X coordinate, when positions are modeled
    y: Optional[float] = None  # hit object Y coordinate, when positions are modeled
    new_combo: bool = False  # whether a NEW_COMBO token was seen in this group
    # Decoded HITSOUND token components (see get_groups for the bit layout):
    hitsounds: list[int] = dataclasses.field(default_factory=list)
    samplesets: list[int] = dataclasses.field(default_factory=list)
    additions: list[int] = dataclasses.field(default_factory=list)
    volumes: list[int] = dataclasses.field(default_factory=list)  # values of VOLUME tokens
+
228
+
229
# Event types that define (and delimit) a hit-object group; all other event
# types are auxiliary attributes attached to the nearest group.
type_events = [
    EventType.CIRCLE,
    EventType.SPINNER,
    EventType.SPINNER_END,
    EventType.SLIDER_HEAD,
    EventType.BEZIER_ANCHOR,
    EventType.PERFECT_ANCHOR,
    EventType.CATMULL_ANCHOR,
    EventType.RED_ANCHOR,
    EventType.LAST_ANCHOR,
    EventType.SLIDER_END,
    EventType.BEAT,
    EventType.MEASURE,
]
243
+
244
+
245
def get_groups(
        events: list[Event],
        *,
        event_times: Optional[list[int]] = None,
        types_first: bool = False
) -> list[Group]:
    """Fold a flat event stream into hit-object groups.

    Auxiliary tokens (time, distance, position, hitsound, ...) are collected
    into the current group; a token from `type_events` either opens a new
    group (types_first) or closes the current one.

    Args:
        events: Flat list of events.
        event_times: Optional per-event times; when given, a group's time is
            taken from its typed event instead of the TIME_SHIFT token.
        types_first: Whether the type token leads its group.

    Returns:
        List of decoded groups, in stream order.
    """
    groups = []
    current = Group()
    for i, event in enumerate(events):
        if event.type == EventType.TIME_SHIFT:
            current.time = event.value
        elif event.type == EventType.DISTANCE:
            current.distance = event.value
        elif event.type == EventType.POS_X:
            current.x = event.value
        elif event.type == EventType.POS_Y:
            current.y = event.value
        elif event.type == EventType.NEW_COMBO:
            current.new_combo = True
        elif event.type == EventType.HITSOUND:
            # Unpack the packed hitsound token: low 3 bits -> hitsound flags,
            # next base-3 digits -> sampleset and addition indices.
            current.hitsounds.append((event.value % 8) * 2)
            current.samplesets.append(((event.value // 8) % 3) + 1)
            current.additions.append(((event.value // 24) % 3) + 1)
        elif event.type == EventType.VOLUME:
            current.volumes.append(event.value)
        elif event.type in type_events:
            if types_first and current.event_type is not None:
                # A typed token opens a new group; flush the previous one.
                groups.append(current)
                current = Group()
            current.event_type = event.type
            if event_times is not None:
                current.time = event_times[i]
            if not types_first:
                # The typed token closes the group when it comes last.
                groups.append(current)
                current = Group()

    # Flush a trailing group that was opened but never closed.
    if current.event_type is not None:
        groups.append(current)

    return groups
289
+
290
+
291
def get_group_indices(events: list[Event], types_first: bool = False) -> list[list[int]]:
    """Partition event indices into per-group index lists.

    Mirrors the grouping logic of `get_groups`, but returns the indices of
    the events in each group instead of decoded Group objects.

    Args:
        events: Flat list of events.
        types_first: Whether the type token leads its group.

    Returns:
        A list of index lists, one per group, covering every event.
    """
    groups = []
    pending = []
    for i, event in enumerate(events):
        pending.append(i)
        if event.type not in type_events:
            continue
        if types_first:
            # The typed token starts a new group: everything gathered before
            # it belongs to the previous group.
            if len(pending) > 1:
                groups.append(pending[:-1])
            pending = [pending[-1]]
        else:
            # The typed token terminates the current group.
            groups.append(pending)
            pending = []

    if pending:
        groups.append(pending)

    return groups
classifier/libs/dataset/ors_dataset.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import random
6
+ from typing import Optional, Callable
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import numpy.typing as npt
11
+ import torch
12
+ from omegaconf import DictConfig
13
+ from slider import Beatmap
14
+ from torch.utils.data import IterableDataset
15
+
16
+ from .data_utils import load_audio_file
17
+ from .osu_parser import OsuParser
18
+ from ..tokenizer import Event, EventType, Tokenizer
19
+
20
+ OSZ_FILE_EXTENSION = ".osz"
21
+ AUDIO_FILE_NAME = "audio.mp3"
22
+ MILISECONDS_PER_SECOND = 1000
23
+ STEPS_PER_MILLISECOND = 0.1
24
+ LABEL_IGNORE_ID = -100
25
+
26
+
27
class OrsDataset(IterableDataset):
    """Iterable dataset over an ORS-format beatmap corpus on disk."""

    __slots__ = (
        "path",
        "start",
        "end",
        "args",
        "parser",
        "tokenizer",
        "beatmap_files",
        "test",
    )

    def __init__(
        self,
        args: DictConfig,
        parser: OsuParser,
        tokenizer: Tokenizer,
        beatmap_files: Optional[list[Path]] = None,
        test: bool = False,
    ):
        """Manage and process ORS dataset.

        Attributes:
            args: Data loading arguments.
            parser: Instance of OsuParser class.
            tokenizer: Instance of Tokenizer class.
            beatmap_files: List of beatmap files to process. Overrides track index range.
            test: Whether to load the test dataset.
        """
        super().__init__()
        # Train and test splits share the directory layout but use disjoint
        # track index ranges from the config.
        self.path = args.test_dataset_path if test else args.train_dataset_path
        self.start = args.test_dataset_start if test else args.train_dataset_start
        self.end = args.test_dataset_end if test else args.train_dataset_end
        self.args = args
        self.parser = parser
        self.tokenizer = tokenizer
        self.beatmap_files = beatmap_files
        self.test = test

    def _get_beatmap_files(self) -> list[Path]:
        """Collect every beatmap file of tracks in [start, end)."""
        if self.beatmap_files is not None:
            return self.beatmap_files

        # Get a list of all beatmap files in the dataset path in the track index range between start and end
        # Track directories are named "TrackXXXXX" with zero-padded indices.
        beatmap_files = []
        track_names = ["Track" + str(i).zfill(5) for i in range(self.start, self.end)]
        for track_name in track_names:
            for beatmap_file in os.listdir(
                os.path.join(self.path, track_name, "beatmaps"),
            ):
                beatmap_files.append(
                    Path(
                        os.path.join(
                            self.path,
                            track_name,
                            "beatmaps",
                            beatmap_file,
                        )
                    ),
                )

        return beatmap_files

    def _get_track_paths(self) -> list[Path]:
        """Collect the track directories of tracks in [start, end)."""
        track_paths = []
        track_names = ["Track" + str(i).zfill(5) for i in range(self.start, self.end)]
        for track_name in track_names:
            track_paths.append(Path(os.path.join(self.path, track_name)))
        return track_paths

    def __iter__(self):
        """Build the (optionally interleaved) sample iterator for this split."""
        beatmap_files = self._get_track_paths() if self.args.per_track else self._get_beatmap_files()

        # Shuffle only for training; keep the test split deterministic.
        if not self.test:
            random.shuffle(beatmap_files)

        # Interleave several sub-iterators so consecutive samples come from
        # different parts of the shuffled dataset.
        if self.args.cycle_length > 1 and not self.test:
            return InterleavingBeatmapDatasetIterable(
                beatmap_files,
                self._iterable_factory,
                self.args.cycle_length,
            )

        return self._iterable_factory(beatmap_files).__iter__()

    def _iterable_factory(self, beatmap_files: list[Path]):
        # One sequential iterator over the given files; used directly or as
        # a shard factory by the interleaving iterator above.
        return BeatmapDatasetIterable(
            beatmap_files,
            self.args,
            self.parser,
            self.tokenizer,
            self.test,
        )
120
+
121
+
122
class InterleavingBeatmapDatasetIterable:
    """Round-robin iterator interleaving several sub-iterables.

    The beatmap files are ceil-divided into `cycle_length` contiguous shards
    and one item is drawn from each shard's iterator in turn, so consecutive
    samples come from different parts of the dataset. Exhausted shards are
    dropped; iteration ends when all shards are exhausted.
    """

    __slots__ = ("workers", "cycle_length", "index")

    def __init__(
            self,
            beatmap_files: list[Path],
            iterable_factory: Callable,
            cycle_length: int,
    ):
        # Ceil-divide the files over the shards; the last shard may be short.
        shard_size = int(np.ceil(len(beatmap_files) / float(cycle_length)))
        self.workers = [
            iter(iterable_factory(
                beatmap_files[shard * shard_size: min(len(beatmap_files), (shard + 1) * shard_size)]
            ))
            for shard in range(cycle_length)
        ]
        self.cycle_length = cycle_length
        self.index = 0

    def __iter__(self) -> "InterleavingBeatmapDatasetIterable":
        return self

    def __next__(self) -> tuple[any, int]:
        # Give each remaining worker at most one chance per call; a worker
        # that raises StopIteration is removed and the next one is tried.
        for _ in range(len(self.workers)):
            try:
                self.index = self.index % len(self.workers)
                item = next(self.workers[self.index])
                self.index += 1
                return item
            except StopIteration:
                self.workers.remove(self.workers[self.index])
        raise StopIteration
157
+
158
+
159
class BeatmapDatasetIterable:
    """Iterator that parses beatmap files, segments their audio into frame
    windows and yields tokenized, padded training sequences."""

    __slots__ = (
        "beatmap_files",
        "args",
        "parser",
        "tokenizer",
        "test",
        "frame_seq_len",
        # NOTE(review): "pre_token_len" and "add_empty_sequences" are declared
        # here but never assigned in the visible code — possibly leftovers
        # from an earlier version.
        "pre_token_len",
        "add_empty_sequences",
    )
170
+
171
    def __init__(
            self,
            beatmap_files: list[Path],
            args: DictConfig,
            parser: OsuParser,
            tokenizer: Tokenizer,
            test: bool,
    ):
        """Sequentially yield samples from the given beatmap files.

        Args:
            beatmap_files: Beatmap files (or track dirs when args.per_track) to iterate.
            args: Data loading arguments.
            parser: Instance of OsuParser class.
            tokenizer: Instance of Tokenizer class.
            test: Whether this iterator serves the test split.
        """
        self.beatmap_files = beatmap_files
        self.args = args
        self.parser = parser
        self.tokenizer = tokenizer
        self.test = test
        # Frames per window; presumably src_seq_len - 1 leaves room for one
        # special token in the encoder input — TODO confirm.
        self.frame_seq_len = args.src_seq_len - 1
185
+
186
+ def _get_frames(self, samples: npt.NDArray) -> tuple[npt.NDArray, npt.NDArray]:
187
+ """Segment audio samples into frames.
188
+
189
+ Each frame has `frame_size` audio samples.
190
+ It will also calculate and return the time of each audio frame, in miliseconds.
191
+
192
+ Args:
193
+ samples: Audio time-series.
194
+
195
+ Returns:
196
+ frames: Audio frames.
197
+ frame_times: Audio frame times.
198
+ """
199
+ samples = np.pad(samples, [0, self.args.hop_length - len(samples) % self.args.hop_length])
200
+ frames = np.reshape(samples, (-1, self.args.hop_length))
201
+ frames_per_milisecond = (
202
+ self.args.sample_rate / self.args.hop_length / MILISECONDS_PER_SECOND
203
+ )
204
+ frame_times = np.arange(len(frames)) / frames_per_milisecond
205
+ return frames, frame_times
206
+
207
+ def _create_sequences(
208
+ self,
209
+ frames: npt.NDArray,
210
+ frame_times: npt.NDArray,
211
+ context: dict,
212
+ extra_data: Optional[dict] = None,
213
+ ) -> list[dict[str, int | npt.NDArray | list[Event]]]:
214
+ """Create frame and token sequences for training/testing.
215
+
216
+ Args:
217
+ frames: Audio frames.
218
+
219
+ Returns:
220
+ A list of source and target sequences.
221
+ """
222
+
223
+ def get_event_indices(events2: list[Event], event_times2: list[int]) -> tuple[list[int], list[int]]:
224
+ if len(events2) == 0:
225
+ return [], []
226
+
227
+ # Corresponding start event index for every audio frame.
228
+ start_indices = []
229
+ event_index = 0
230
+
231
+ for current_time in frame_times:
232
+ while event_index < len(events2) and event_times2[event_index] < current_time:
233
+ event_index += 1
234
+ start_indices.append(event_index)
235
+
236
+ # Corresponding end event index for every audio frame.
237
+ end_indices = start_indices[1:] + [len(events2)]
238
+
239
+ return start_indices, end_indices
240
+
241
+ start_indices, end_indices = get_event_indices(context["events"], context["event_times"])
242
+
243
+ sequences = []
244
+ n_frames = len(frames)
245
+ offset = random.randint(0, self.frame_seq_len)
246
+ # Divide audio frames into splits
247
+ for frame_start_idx in range(offset, n_frames, self.frame_seq_len):
248
+ frame_end_idx = min(frame_start_idx + self.frame_seq_len, n_frames)
249
+
250
+ def slice_events(context, frame_start_idx, frame_end_idx):
251
+ if len(context["events"]) == 0:
252
+ return []
253
+ event_start_idx = start_indices[frame_start_idx]
254
+ event_end_idx = end_indices[frame_end_idx - 1]
255
+ return context["events"][event_start_idx:event_end_idx]
256
+
257
+ def slice_context(context, frame_start_idx, frame_end_idx):
258
+ return {"events": slice_events(context, frame_start_idx, frame_end_idx)}
259
+
260
+ # Create the sequence
261
+ sequence = {
262
+ "time": frame_times[frame_start_idx],
263
+ "frames": frames[frame_start_idx:frame_end_idx],
264
+ "context": slice_context(context, frame_start_idx, frame_end_idx),
265
+ } | extra_data
266
+
267
+ sequences.append(sequence)
268
+
269
+ return sequences
270
+
271
+ def _normalize_time_shifts(self, sequence: dict) -> dict:
272
+ """Make all time shifts in the sequence relative to the start time of the sequence,
273
+ and normalize time values.
274
+
275
+ Args:
276
+ sequence: The input sequence.
277
+
278
+ Returns:
279
+ The same sequence with trimmed time shifts.
280
+ """
281
+
282
+ def process(events: list[Event], start_time) -> list[Event] | tuple[list[Event], int]:
283
+ for i, event in enumerate(events):
284
+ if event.type == EventType.TIME_SHIFT:
285
+ # We cant modify the event objects themselves because that will affect subsequent sequences
286
+ events[i] = Event(EventType.TIME_SHIFT, int((event.value - start_time) * STEPS_PER_MILLISECOND))
287
+
288
+ return events
289
+
290
+ start_time = sequence["time"]
291
+ del sequence["time"]
292
+
293
+ sequence["context"]["events"] = process(sequence["context"]["events"], start_time)
294
+
295
+ return sequence
296
+
297
+ def _tokenize_sequence(self, sequence: dict) -> dict:
298
+ """Tokenize the event sequence.
299
+
300
+ Begin token sequence with `[SOS]` token (start-of-sequence).
301
+ End token sequence with `[EOS]` token (end-of-sequence).
302
+
303
+ Args:
304
+ sequence: The input sequence.
305
+
306
+ Returns:
307
+ The same sequence with tokenized events.
308
+ """
309
+ context = sequence["context"]
310
+ tokens = torch.empty(len(context["events"]), dtype=torch.long)
311
+ for i, event in enumerate(context["events"]):
312
+ tokens[i] = self.tokenizer.encode(event)
313
+ context["tokens"] = tokens
314
+
315
+ return sequence
316
+
317
    def _pad_and_split_token_sequence(self, sequence: dict) -> dict:
        """Pad token sequence to a fixed length and split decoder input and labels.

        Pad with `[PAD]` tokens until `tgt_seq_len`.

        Token sequence (w/o last token) is the input to the transformer decoder,
        token sequence (w/o first token) is the label, a.k.a. decoder ground truth.

        Prefix the token sequence with the pre_tokens sequence.

        NOTE(review): in this classifier variant only decoder_input_ids and
        the attention mask are produced — no label split actually happens;
        the docstring above looks inherited from the generator codebase.

        Args:
            sequence: The input sequence.

        Returns:
            The same sequence with padded tokens.
        """
        # Count reducible tokens, pre_tokens and context tokens
        num_tokens = len(sequence["context"]["tokens"])

        # Trim tokens to target sequence length
        # n + padding = tgt_seq_len
        n = min(self.args.tgt_seq_len, num_tokens)
        si = 0  # write offset for the tokens within the padded buffer

        input_tokens = torch.full((self.args.tgt_seq_len,), self.tokenizer.pad_id, dtype=torch.long)

        tokens = sequence["context"]["tokens"]

        input_tokens[si:si + n] = tokens[:n]

        # Randomize some input tokens
        def randomize_tokens(tokens):
            # Jitter only TIME_SHIFT tokens by a uniform integer offset in
            # [-timing_random_offset, +timing_random_offset], clamped so the
            # result remains a valid TIME_SHIFT token id; all other tokens
            # pass through unchanged.
            offset = torch.randint(low=-self.args.timing_random_offset, high=self.args.timing_random_offset + 1,
                                   size=tokens.shape)
            return torch.where((self.tokenizer.event_start[EventType.TIME_SHIFT] <= tokens) & (
                    tokens < self.tokenizer.event_end[EventType.TIME_SHIFT]),
                               torch.clamp(tokens + offset,
                                           self.tokenizer.event_start[EventType.TIME_SHIFT],
                                           self.tokenizer.event_end[EventType.TIME_SHIFT] - 1),
                               tokens)

        if self.args.timing_random_offset > 0:
            input_tokens[si:si + n] = randomize_tokens(input_tokens[si:si + n])

        sequence["decoder_input_ids"] = input_tokens
        sequence["decoder_attention_mask"] = input_tokens != self.tokenizer.pad_id

        # The raw event context is no longer needed after tokenization.
        del sequence["context"]

        return sequence
367
+
368
def _pad_frame_sequence(self, sequence: dict) -> dict:
    """Pad or truncate the audio frame matrix to ``self.frame_seq_len`` rows, then flatten.

    Frame sequence can be further processed into Mel spectrogram frames,
    which is the input to the transformer encoder.

    Args:
        sequence: The input sequence; ``sequence["frames"]`` is a 2-D numpy array.

    Returns:
        The same sequence with ``frames`` replaced by a flattened float32 tensor.
    """
    frames = torch.from_numpy(sequence["frames"]).to(torch.float32)

    # Fast path: already the right length, just flatten.
    if frames.shape[0] == self.frame_seq_len:
        sequence["frames"] = torch.flatten(frames)
        return sequence

    # Zero-pad (or truncate) to exactly frame_seq_len rows.
    keep = min(self.frame_seq_len, len(frames))
    buffer = torch.zeros(
        self.frame_seq_len,
        frames.shape[-1],
        dtype=frames.dtype,
        device=frames.device,
    )
    buffer[:keep] = frames[:keep]
    sequence["frames"] = torch.flatten(buffer)
    return sequence
396
+
397
def __iter__(self):
    """Iterate samples either per track or per beatmap, per ``args.per_track``."""
    if self.args.per_track:
        return self._get_next_tracks()
    return self._get_next_beatmaps()
399
+
400
@staticmethod
def _load_metadata(track_path: Path) -> dict:
    """Load and return the track-level ``metadata.json`` as a dict."""
    with (track_path / "metadata.json").open() as fp:
        return json.load(fp)
405
+
406
def _get_difficulty(self, metadata: dict, beatmap_name: str, speed: float = 1.0, beatmap: Beatmap = None) -> float:
    """Return the star rating for a beatmap at the given speed multiplier.

    Uses a live star-rating calculation only when a parsed beatmap is available
    and the cached metadata cannot answer (non-standard speed, or the augment
    range is pinned to 1.5x); otherwise reads the precomputed metadata entry
    ("64" = DT/1.5x mods, "0" = nomod).
    """
    needs_live_calc = beatmap is not None and (
        all(e == 1.5 for e in self.args.dt_augment_range) or speed not in [1.0, 1.5]
    )
    if needs_live_calc:
        return beatmap.stars(speed_scale=speed)

    rating_key = "64" if speed == 1.5 else "0"
    return metadata["Beatmaps"][beatmap_name]["StandardStarRating"][rating_key]
413
+
414
@staticmethod
def _get_idx(metadata: dict, beatmap_name: str):
    """Return the stored index of ``beatmap_name`` within the track metadata."""
    beatmap_meta = metadata["Beatmaps"][beatmap_name]
    return beatmap_meta["Index"]
417
+
418
def _get_speed_augment(self):
    """Sample a speed (DT) augmentation factor.

    With probability ``args.dt_augment_prob`` returns a uniform draw from
    ``args.dt_augment_range``; otherwise returns 1.0 (no augmentation).
    """
    lo, hi = self.args.dt_augment_range
    if random.random() < self.args.dt_augment_prob:
        return lo + random.random() * (hi - lo)
    return 1.0
421
+
422
def _get_next_beatmaps(self) -> dict:
    """Yield training samples for every beatmap file in ``self.beatmap_files``.

    Applies the minimum-difficulty filter and a random speed (DT) augmentation,
    then delegates per-beatmap sample generation to ``_get_next_beatmap``.
    """
    for beatmap_path in self.beatmap_files:
        track_path = beatmap_path.parents[1]
        metadata = self._load_metadata(track_path)

        # Skip beatmaps below the configured difficulty threshold.
        if self.args.min_difficulty > 0 and self._get_difficulty(metadata,
                                                                 beatmap_path.stem) < self.args.min_difficulty:
            continue

        speed = self._get_speed_augment()
        # BUGFIX: Path.glob() already yields paths rooted at track_path, so use the
        # result directly. The previous `track_path / glob_result` join duplicated
        # the prefix for relative paths (it only worked for absolute paths because
        # joining with an absolute right operand discards the left one).
        audio_path = list(track_path.glob('audio.*'))[0]
        audio_samples = load_audio_file(audio_path, self.args.sample_rate, speed)

        for sample in self._get_next_beatmap(audio_samples, beatmap_path, speed):
            yield sample
436
+
437
def _get_next_tracks(self) -> dict:
    """Yield training samples for every beatmap of every track in ``self.beatmap_files``.

    A track is skipped entirely when all of its beatmaps fall below the
    minimum difficulty; individual beatmaps are additionally filtered.
    The audio is decoded once per track and shared by all its beatmaps.
    """
    for track_path in self.beatmap_files:
        metadata = self._load_metadata(track_path)

        # Skip the whole track if no beatmap reaches the difficulty threshold.
        if self.args.min_difficulty > 0 and all(self._get_difficulty(metadata, beatmap_name)
                                                < self.args.min_difficulty for beatmap_name in
                                                metadata["Beatmaps"]):
            continue

        speed = self._get_speed_augment()
        # BUGFIX: Path.glob() already yields paths rooted at track_path, so use the
        # result directly. The previous `track_path / glob_result` join duplicated
        # the prefix for relative paths (it only worked for absolute paths because
        # joining with an absolute right operand discards the left one).
        audio_path = list(track_path.glob('audio.*'))[0]
        audio_samples = load_audio_file(audio_path, self.args.sample_rate, speed)

        for beatmap_name in metadata["Beatmaps"]:
            beatmap_path = (track_path / "beatmaps" / beatmap_name).with_suffix(".osu")

            # Per-beatmap difficulty filter.
            if self.args.min_difficulty > 0 and self._get_difficulty(metadata,
                                                                     beatmap_name) < self.args.min_difficulty:
                continue

            for sample in self._get_next_beatmap(audio_samples, beatmap_path, speed):
                yield sample
459
+
460
def _get_next_beatmap(self, audio_samples, beatmap_path: Path, speed: float) -> dict:
    """Generate model-ready training samples for a single beatmap.

    Args:
        audio_samples: Decoded (and speed-adjusted) audio for the whole track.
        beatmap_path: Path to the .osu file to parse.
        speed: Speed multiplier already applied to the audio; forwarded to the
            event parser so event times stay in sync.

    Yields:
        Tokenized, padded sample dicts. Yields nothing when the beatmap id is
        absent from the tokenizer's beatmap->mapper table (unlabelable).
    """
    frames, frame_times = self._get_frames(audio_samples)
    osu_beatmap = Beatmap.from_path(beatmap_path)

    # Only beatmaps with a known mapper can be labeled; skip the rest.
    if osu_beatmap.beatmap_id not in self.tokenizer.beatmap_mapper:
        return

    # Classification target: the mapper's class index for this beatmap.
    extra_data = {
        "labels": self.tokenizer.mapper_idx[self.tokenizer.beatmap_mapper[osu_beatmap.beatmap_id]],
    }

    # Optional random horizontal/vertical flip augmentation (50/50 each).
    flip_x, flip_y = False, False
    if self.args.augment_flip:
        flip_x, flip_y = random.random() < 0.5, random.random() < 0.5

    events, event_times = self.parser.parse(osu_beatmap, speed, flip_x, flip_y)
    in_context = {"events": events, "event_times": event_times}

    sequences = self._create_sequences(
        frames,
        frame_times,
        in_context,
        extra_data,
    )

    # Post-process each windowed sequence into model inputs.
    for sequence in sequences:
        sequence = self._normalize_time_shifts(sequence)
        sequence = self._tokenize_sequence(sequence)
        sequence = self._pad_frame_sequence(sequence)
        sequence = self._pad_and_split_token_sequence(sequence)
        yield sequence
classifier/libs/dataset/osu_parser.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from datetime import timedelta
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+ from omegaconf import DictConfig
9
+ from slider import Beatmap, Circle, Slider, Spinner
10
+ from slider.curve import Linear, Catmull, Perfect, MultiBezier
11
+
12
+ from ..tokenizer import Event, EventType, Tokenizer
13
+ from .data_utils import merge_events, speed_events
14
+
15
+
16
class OsuParser:
    """Converts a parsed ``slider.Beatmap`` into parallel lists of ``Event``
    objects and their millisecond timestamps, honoring the feature toggles in
    the data config (timing, snapping, hitsounds, distances, positions)."""

    def __init__(self, args: DictConfig, tokenizer: Tokenizer) -> None:
        # Feature toggles from the data config section.
        self.types_first = args.data.types_first  # emit the type event before (True) or after its attribute events
        self.add_timing = args.data.add_timing
        self.add_snapping = args.data.add_snapping
        self.add_timing_points = args.data.add_timing_points
        self.add_hitsounds = args.data.add_hitsounds
        self.add_distances = args.data.add_distances
        self.add_positions = args.data.add_positions
        if self.add_positions:
            # Position events are quantized to a grid of `position_precision` pixels.
            self.position_precision = args.data.position_precision
            self.position_split_axes = args.data.position_split_axes
            x_min, x_max, y_min, y_max = args.data.position_range
            self.x_min = x_min / self.position_precision
            self.x_max = x_max / self.position_precision
            self.y_min = y_min / self.position_precision
            self.y_max = y_max / self.position_precision
            # Grid width in x, used to flatten (x, y) into a single POS index.
            self.x_count = self.x_max - self.x_min + 1
        if self.add_distances:
            # Distance values are clipped to the tokenizer's DISTANCE range.
            dist_range = tokenizer.event_range[EventType.DISTANCE]
            self.dist_min = dist_range.min_value
            self.dist_max = dist_range.max_value

    def parse(
        self,
        beatmap: Beatmap,
        speed: float = 1.0,
        flip_x: bool = False,
        flip_y: bool = False
    ) -> tuple[list[Event], list[int]]:
        # noinspection PyUnresolvedReferences
        """Parse an .osu beatmap.

        Each hit object is parsed into a list of Event objects, in order of its
        appearance in the beatmap. In other words, in ascending order of time.

        Args:
            beatmap: Beatmap object parsed from an .osu file.
            speed: Speed multiplier for the beatmap.
            flip_x: Whether to flip the x-axis.
            flip_y: Whether to flip the y-axis.

        Returns:
            events: List of Event object lists.
            event_times: List of event times.

        Example::
            >>> beatmap = [
                "64,80,11000,1,0",
                "100,100,16000,2,0,B|200:200|250:200|250:200|300:150,2"
            ]
            >>> events = parse(beatmap)
            >>> print(events)
            [
                Event(EventType.TIME_SHIFT, 11000), Event(EventType.DISTANCE, 36), Event(EventType.CIRCLE),
                Event(EventType.TIME_SHIFT, 16000), Event(EventType.DISTANCE, 42), Event(EventType.SLIDER_HEAD),
                Event(EventType.TIME_SHIFT, 16500), Event(EventType.DISTANCE, 141), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 17000), Event(EventType.DISTANCE, 50), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 17500), Event(EventType.DISTANCE, 10), Event(EventType.BEZIER_ANCHOR),
                Event(EventType.TIME_SHIFT, 18000), Event(EventType.DISTANCE, 64), Event(EventType.LAST_ANCHOR),
                Event(EventType.TIME_SHIFT, 20000), Event(EventType.DISTANCE, 11), Event(EventType.SLIDER_END)
            ]
        """
        hit_objects = beatmap.hit_objects(stacking=False)
        # Distances are measured from the previous object; start at playfield center.
        last_pos = np.array((256, 192))
        events = []
        event_times = []

        for hit_object in hit_objects:
            if isinstance(hit_object, Circle):
                last_pos = self._parse_circle(hit_object, events, event_times, last_pos, beatmap, flip_x, flip_y)
            elif isinstance(hit_object, Slider):
                last_pos = self._parse_slider(hit_object, events, event_times, last_pos, beatmap, flip_x, flip_y)
            elif isinstance(hit_object, Spinner):
                last_pos = self._parse_spinner(hit_object, events, event_times, beatmap)

        if self.add_timing:
            # Interleave beat/measure/timing-point events with the hit-object events.
            timing_events, timing_times = self.parse_timing(beatmap)
            events, event_times = merge_events(timing_events, timing_times, events, event_times)

        if speed != 1.0:
            # Rescale all event times to the augmented playback speed.
            events, event_times = speed_events(events, event_times, speed)

        return events, event_times

    def parse_timing(self, beatmap: Beatmap, speed: float = 1.0) -> tuple[list[Event], list[int]]:
        """Extract all timing information from a beatmap.

        Emits a TIMING_POINT / MEASURE / BEAT event on every beat of every
        uninherited (BPM-defining) timing point, up to the last hit object.
        """
        events = []
        event_times = []
        hit_objects = beatmap.hit_objects(stacking=False)
        if len(hit_objects) == 0:
            last_time = timedelta(milliseconds=0)
        else:
            last_ho = beatmap.hit_objects(stacking=False)[-1]
            # Sliders/spinners expose end_time; circles only have time.
            last_time = last_ho.end_time if hasattr(last_ho, "end_time") else last_ho.time

        # Get all timing points with BPM changes
        timing_points = [tp for tp in beatmap.timing_points if tp.bpm]

        for i, tp in enumerate(timing_points):
            # Generate beat and measure events until the next timing point
            # (stop 10 ms early so the next section's first beat is not duplicated).
            next_tp = timing_points[i + 1] if i + 1 < len(timing_points) else None
            next_time = next_tp.offset - timedelta(milliseconds=10) if next_tp else last_time
            time = tp.offset
            measure_counter = 0
            beat_delta = timedelta(milliseconds=tp.ms_per_beat)
            while time <= next_time:
                # First beat of the section is the timing point itself (if enabled),
                # every tp.meter-th beat is a measure, the rest are plain beats.
                if self.add_timing_points and measure_counter == 0:
                    event_type = EventType.TIMING_POINT
                elif measure_counter % tp.meter == 0:
                    event_type = EventType.MEASURE
                else:
                    event_type = EventType.BEAT

                self._add_group(
                    event_type,
                    time,
                    events,
                    event_times,
                    beatmap,
                    time_event=True,
                    add_snap=False,
                )

                measure_counter += 1
                time += beat_delta

        if speed != 1.0:
            events, event_times = speed_events(events, event_times, speed)

        return events, event_times

    @staticmethod
    def uninherited_point_at(time: timedelta, beatmap: Beatmap):
        """Return the uninherited (BPM-defining) timing point in effect at `time`."""
        tp = beatmap.timing_point_at(time)
        return tp if tp.parent is None else tp.parent

    @staticmethod
    def hitsound_point_at(time: timedelta, beatmap: Beatmap):
        """Return the timing point governing hitsounds at `time`.

        Queries 5 ms late so an inherited point placed exactly on the object
        still applies (mirrors the in-game hitsound lookup behavior).
        """
        hs_query = time + timedelta(milliseconds=5)
        return beatmap.timing_point_at(hs_query)

    def _add_time_event(self, time: timedelta, beatmap: Beatmap, events: list[Event], event_times: list[int],
                        add_snap: bool = True) -> None:
        """Add a TIME_SHIFT (and optionally a SNAPPING) event to the event list.

        Args:
            time: Time of the snapping event.
            beatmap: Beatmap object.
            events: List of events to add to.
            event_times: Parallel list of event times to add to.
            add_snap: Whether to add a snapping event.
        """
        # +1e-5 guards against float truncation just below a whole millisecond.
        time_ms = int(time.total_seconds() * 1000 + 1e-5)
        events.append(Event(EventType.TIME_SHIFT, time_ms))
        event_times.append(time_ms)

        if not add_snap or not self.add_snapping:
            return

        if len(beatmap.timing_points) > 0:
            tp = self.uninherited_point_at(time, beatmap)
            # Position of `time` within the section, in beats.
            beats = (time - tp.offset).total_seconds() * 1000 / tp.ms_per_beat
            snapping = 0
            # Find the smallest beat divisor (1..16) that reproduces this time.
            for i in range(1, 17):
                # If the difference between the time and the snapped time is less than 2 ms, that is the correct snapping
                if abs(beats - round(beats * i) / i) * tp.ms_per_beat < 2:
                    snapping = i
                    break
            else:
                snapping = 0  # unsnapped: no divisor up to 1/16 matches

            events.append(Event(EventType.SNAPPING, snapping))
            event_times.append(time_ms)

    def _add_hitsound_event(self, time: timedelta, group_time: int, hitsound: int, addition: str, beatmap: Beatmap,
                            events: list[Event], event_times: list[int]) -> None:
        """Add HITSOUND and VOLUME events for a hit at `time`, stamped with `group_time`.

        Args:
            time: Time used to look up the governing timing point.
            group_time: Event-time (ms) recorded for the emitted events.
            hitsound: Raw hitsound bitmask from the hit object.
            addition: "sampleSet:additionSet[:...]" addition string.
        """
        if not self.add_hitsounds:
            return

        if len(beatmap.timing_points) > 0:
            tp = self.hitsound_point_at(time, beatmap)
            tp_sample_set = tp.sample_type if tp.sample_type != 0 else 2  # Inherit to soft sample set
            tp_volume = tp.volume
        else:
            # No timing points at all: fall back to soft set at full volume.
            tp_sample_set = 2
            tp_volume = 100

        # Per-object sample sets override the timing point's; "0" means inherit.
        addition_split = addition.split(":")
        sample_set = int(addition_split[0]) if addition_split[0] != "0" else tp_sample_set
        addition_set = int(addition_split[1]) if addition_split[1] != "0" else sample_set

        sample_set = sample_set if 0 < sample_set < 4 else 1  # Overflow default to normal sample set
        addition_set = addition_set if 0 < addition_set < 4 else 1  # Overflow default to normal sample set
        hitsound = hitsound & 14  # Only take the bits for normal, whistle, and finish

        # Pack (hitsound bits, sample set, addition set) into one index:
        # 8 hitsound combos x 3 sample sets x 3 addition sets.
        hitsound_idx = hitsound // 2 + 8 * (sample_set - 1) + 24 * (addition_set - 1)

        events.append(Event(EventType.HITSOUND, hitsound_idx))
        events.append(Event(EventType.VOLUME, tp_volume))
        event_times.append(group_time)
        event_times.append(group_time)

    def _clip_dist(self, dist: int) -> int:
        """Clip distance to valid range."""
        return int(np.clip(dist, self.dist_min, self.dist_max))

    def _scale_clip_pos(self, pos: npt.NDArray) -> Tuple[int, int]:
        """Quantize a pixel position to grid units and clip to the valid range."""
        p = pos / self.position_precision
        return int(np.clip(p[0], self.x_min, self.x_max)), int(np.clip(p[1], self.y_min, self.y_max))

    def _add_position_event(self, pos: npt.NDArray, last_pos: npt.NDArray, time: timedelta, events: list[Event],
                            event_times: list[int], flip_x: bool, flip_y: bool) -> npt.NDArray:
        """Emit DISTANCE and/or position events for `pos`; returns the new last position.

        Note: the returned (and distance-reference) position is the unflipped
        one — flips are applied only to the emitted position coordinates.
        """
        time_ms = int(time.total_seconds() * 1000 + 1e-5)
        if self.add_distances:
            # Euclidean distance from the previous object, clipped to range.
            dist = self._clip_dist(np.linalg.norm(pos - last_pos))
            events.append(Event(EventType.DISTANCE, dist))
            event_times.append(time_ms)

        if self.add_positions:
            # Mirror across the 512x384 playfield if flip augmentation is active.
            pos_modified = pos.copy()
            if flip_x:
                pos_modified[0] = 512 - pos_modified[0]
            if flip_y:
                pos_modified[1] = 384 - pos_modified[1]

            p = self._scale_clip_pos(pos_modified)
            if self.position_split_axes:
                # Separate POS_X / POS_Y tokens.
                events.append(Event(EventType.POS_X, p[0]))
                events.append(Event(EventType.POS_Y, p[1]))
                event_times.append(time_ms)
                event_times.append(time_ms)
            else:
                # Single flattened grid index: x + y * grid_width.
                events.append(Event(EventType.POS, (p[0] - self.x_min) + (p[1] - self.y_min) * self.x_count))
                event_times.append(time_ms)

        return pos

    def _add_group(
        self,
        event_type: EventType,
        time: timedelta,
        events: list[Event],
        event_times: list[int],
        beatmap: Beatmap,
        *,
        time_event: bool = False,
        add_snap=True,
        pos: npt.NDArray = None,
        last_pos: npt.NDArray = None,
        new_combo: bool = False,
        hitsound_ref_times: list[timedelta] = None,
        hitsounds: list[int] = None,
        additions: list[str] = None,
        flip_x: bool = False,
        flip_y: bool = False,
    ) -> npt.NDArray:
        """Add a group of events to the event list.

        A "group" is one typed event plus its optional attribute events
        (time/snapping, position/distance, new combo, hitsounds). The type
        event is emitted first or last depending on ``self.types_first``.
        Returns the (possibly updated) last position for distance tracking.
        """
        time_ms = int(time.total_seconds() * 1000 + 1e-5) if time is not None else None

        if self.types_first:
            events.append(Event(event_type))
            event_times.append(time_ms)
        if time_event:
            self._add_time_event(time, beatmap, events, event_times, add_snap)
        if pos is not None:
            last_pos = self._add_position_event(pos, last_pos, time, events, event_times, flip_x, flip_y)
        if new_combo:
            events.append(Event(EventType.NEW_COMBO))
            event_times.append(time_ms)
        if hitsound_ref_times is not None:
            # hitsounds/additions must be parallel to hitsound_ref_times.
            for i, ref_time in enumerate(hitsound_ref_times):
                self._add_hitsound_event(ref_time, time_ms, hitsounds[i], additions[i], beatmap, events, event_times)
        if not self.types_first:
            events.append(Event(event_type))
            event_times.append(time_ms)

        return last_pos

    def _parse_circle(self, circle: Circle, events: list[Event], event_times: list[int], last_pos: npt.NDArray,
                      beatmap: Beatmap, flip_x: bool, flip_y: bool) -> npt.NDArray:
        """Parse a circle hit object.

        Args:
            circle: Circle object.
            events: List of events to add to.
            last_pos: Last position of the hit objects.

        Returns:
            pos: Position of the circle.
        """
        return self._add_group(
            EventType.CIRCLE,
            circle.time,
            events,
            event_times,
            beatmap,
            time_event=True,
            pos=np.array(circle.position),
            last_pos=last_pos,
            new_combo=circle.new_combo,
            hitsound_ref_times=[circle.time],
            hitsounds=[circle.hitsound],
            additions=[circle.addition],
            flip_x=flip_x,
            flip_y=flip_y,
        )

    def _parse_slider(self, slider: Slider, events: list[Event], event_times: list[int], last_pos: npt.NDArray,
                      beatmap: Beatmap, flip_x: bool, flip_y: bool) -> npt.NDArray:
        """Parse a slider hit object.

        Emits SLIDER_HEAD, the curve's anchor events, LAST_ANCHOR (with body
        and repeat-edge hitsounds) and SLIDER_END (with tail hitsound).

        Args:
            slider: Slider object.
            events: List of events to add to.
            last_pos: Last position of the hit objects.

        Returns:
            pos: Last position of the slider.
        """
        # Ignore sliders which are too big
        if len(slider.curve.points) >= 100:
            return last_pos

        last_pos = self._add_group(
            EventType.SLIDER_HEAD,
            slider.time,
            events,
            event_times,
            beatmap,
            time_event=True,
            pos=np.array(slider.position),
            last_pos=last_pos,
            new_combo=slider.new_combo,
            hitsound_ref_times=[slider.time],
            hitsounds=[slider.edge_sounds[0] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[0] if len(slider.edge_additions) > 0 else '0:0'],
            flip_x=flip_x,
            flip_y=flip_y,
        )

        # Duration of a single slider span (one traversal, ignoring repeats).
        duration: timedelta = (slider.end_time - slider.time) / slider.repeat
        control_point_count = len(slider.curve.points)

        # `last_pos=last_pos` default binds the current value (avoids late binding).
        def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
            # Anchors exclude the head (0) and tail (count - 1) control points.
            for i in range(1, control_point_count - 1):
                last_pos = add_anchor(event_type, i, last_pos)

            return last_pos

        def add_anchor(event_type: EventType, i: int, last_pos: npt.NDArray) -> npt.NDArray:
            # Anchor times are interpolated linearly along the first span.
            return self._add_group(
                event_type,
                slider.time + i / (control_point_count - 1) * duration,
                events,
                event_times,
                beatmap,
                pos=np.array(slider.curve.points[i]),
                last_pos=last_pos,
                flip_x=flip_x,
                flip_y=flip_y,
            )

        if isinstance(slider.curve, Linear):
            last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
        elif isinstance(slider.curve, Catmull):
            last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
        elif isinstance(slider.curve, Perfect):
            last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
        elif isinstance(slider.curve, MultiBezier):
            # Duplicated consecutive points mark red (sharp) anchors in .osu format.
            for i in range(1, control_point_count - 1):
                if slider.curve.points[i] == slider.curve.points[i + 1]:
                    last_pos = add_anchor(EventType.RED_ANCHOR, i, last_pos)
                elif slider.curve.points[i] != slider.curve.points[i - 1]:
                    last_pos = add_anchor(EventType.BEZIER_ANCHOR, i, last_pos)

        # Add body hitsounds and remaining edge hitsounds
        last_pos = self._add_group(
            EventType.LAST_ANCHOR,
            slider.time + duration,
            events,
            event_times,
            beatmap,
            time_event=True,
            pos=np.array(slider.curve.points[-1]),
            last_pos=last_pos,
            hitsound_ref_times=[slider.time + timedelta(milliseconds=1)] + [slider.time + i * duration for i in
                                                                            range(1, slider.repeat)],
            hitsounds=[slider.hitsound] + [slider.edge_sounds[i] if len(slider.edge_sounds) > i else 0 for i in
                                           range(1, slider.repeat)],
            additions=[slider.addition] + [slider.edge_additions[i] if len(slider.edge_additions) > i else '0:0' for i
                                           in range(1, slider.repeat)],
            flip_x=flip_x,
            flip_y=flip_y,
        )

        return self._add_group(
            EventType.SLIDER_END,
            slider.end_time,
            events,
            event_times,
            beatmap,
            time_event=True,
            pos=np.array(slider.curve(1)),
            last_pos=last_pos,
            hitsound_ref_times=[slider.end_time],
            hitsounds=[slider.edge_sounds[-1] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[-1] if len(slider.edge_additions) > 0 else '0:0'],
            flip_x=flip_x,
            flip_y=flip_y,
        )

    def _parse_spinner(self, spinner: Spinner, events: list[Event], event_times: list[int],
                       beatmap: Beatmap) -> npt.NDArray:
        """Parse a spinner hit object.

        Args:
            spinner: Spinner object.
            events: List of events to add to.

        Returns:
            pos: Last position of the spinner (playfield center).
        """
        self._add_group(
            EventType.SPINNER,
            spinner.time,
            events,
            event_times,
            beatmap,
            time_event=True,
        )

        self._add_group(
            EventType.SPINNER_END,
            spinner.end_time,
            events,
            event_times,
            beatmap,
            time_event=True,
            hitsound_ref_times=[spinner.end_time],
            hitsounds=[spinner.hitsound],
            additions=[spinner.addition],
        )

        # Spinners recenter the cursor; subsequent distances start from center.
        return np.array((256, 192))
classifier/libs/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import OsuClassifier
classifier/libs/model/model.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from omegaconf import DictConfig
9
+ from transformers import T5Config, WhisperConfig, T5Model, WhisperModel
10
+ from transformers.modeling_outputs import Seq2SeqModelOutput
11
+
12
+ from .spectrogram import MelSpectrogram
13
+ from ..tokenizer import Tokenizer
14
+
15
+ LABEL_IGNORE_ID = -100
16
+
17
+
18
+ @dataclass
19
+ class OsuClassifierOutput:
20
+ loss: Optional[torch.FloatTensor] = None
21
+ logits: Optional[torch.FloatTensor] = None
22
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
23
+ decoder_last_hidden_state: Optional[torch.FloatTensor] = None
24
+ feature_vector: Optional[torch.FloatTensor] = None
25
+
26
+
27
+ def get_backbone_model(args, tokenizer: Tokenizer):
28
+ if args.model.name.startswith("google/t5"):
29
+ config = T5Config.from_pretrained(args.model.name)
30
+ elif args.model.name.startswith("openai/whisper"):
31
+ config = WhisperConfig.from_pretrained(args.model.name)
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ config.vocab_size = tokenizer.vocab_size
36
+
37
+ if hasattr(args.model, "overwrite"):
38
+ for k, v in args.model.overwrite.items():
39
+ assert hasattr(config, k), f"config does not have attribute {k}"
40
+ setattr(config, k, v)
41
+
42
+ if hasattr(args.model, "add_config"):
43
+ for k, v in args.model.add_config.items():
44
+ assert not hasattr(config, k), f"config already has attribute {k}"
45
+ setattr(config, k, v)
46
+
47
+ if args.model.name.startswith("google/t5"):
48
+ model = T5Model(config)
49
+ elif args.model.name.startswith("openai/whisper"):
50
+ config.use_cache = False
51
+ config.num_mel_bins = config.d_model
52
+ config.pad_token_id = tokenizer.pad_id
53
+ config.max_source_positions = args.data.src_seq_len // 2
54
+ config.max_target_positions = args.data.tgt_seq_len
55
+ model = WhisperModel(config)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ return model, config.d_model
60
+
61
+
62
+ class OsuClassifier(nn.Module):
63
+ __slots__ = [
64
+ "spectrogram",
65
+ "decoder_embedder",
66
+ "encoder_embedder",
67
+ "transformer",
68
+ "style_embedder",
69
+ "num_classes",
70
+ "input_features",
71
+ "projector",
72
+ "classifier",
73
+ "vocab_size",
74
+ "loss_fn",
75
+ ]
76
+
77
+ def __init__(self, args: DictConfig, tokenizer: Tokenizer):
78
+ super().__init__()
79
+
80
+ self.transformer, d_model = get_backbone_model(args, tokenizer)
81
+ self.num_classes = tokenizer.num_classes
82
+ self.input_features = args.model.input_features
83
+
84
+ self.decoder_embedder = nn.Embedding(tokenizer.vocab_size, d_model)
85
+ self.decoder_embedder.weight.data.normal_(mean=0.0, std=1.0)
86
+
87
+ self.spectrogram = MelSpectrogram(
88
+ args.model.spectrogram.sample_rate, args.model.spectrogram.n_fft,
89
+ args.model.spectrogram.n_mels, args.model.spectrogram.hop_length
90
+ )
91
+
92
+ self.encoder_embedder = nn.Linear(args.model.spectrogram.n_mels, d_model)
93
+
94
+ self.projector = nn.Linear(d_model, args.model.classifier_proj_size)
95
+ self.classifier = nn.Linear(args.model.classifier_proj_size, tokenizer.num_classes)
96
+
97
+ self.vocab_size = tokenizer.vocab_size
98
+ self.loss_fn = nn.CrossEntropyLoss()
99
+
100
+ def forward(
101
+ self,
102
+ frames: Optional[torch.FloatTensor] = None,
103
+ decoder_input_ids: Optional[torch.Tensor] = None,
104
+ labels: Optional[torch.LongTensor] = None,
105
+ **kwargs
106
+ ) -> OsuClassifierOutput:
107
+ """
108
+ frames: B x L_encoder x mel_bins, float32
109
+ decoder_input_ids: B x L_decoder, int64
110
+ beatmap_id: B, int64
111
+ encoder_outputs: B x L_encoder x D, float32
112
+ """
113
+
114
+ frames = self.spectrogram(frames) # (N, L, M)
115
+ inputs_embeds = self.encoder_embedder(frames)
116
+ decoder_inputs_embeds = self.decoder_embedder(decoder_input_ids)
117
+
118
+ if self.input_features:
119
+ input_features = torch.swapaxes(inputs_embeds, 1, 2) if inputs_embeds is not None else None
120
+ # noinspection PyTypeChecker
121
+ base_output: Seq2SeqModelOutput = self.transformer.forward(input_features=input_features,
122
+ decoder_inputs_embeds=decoder_inputs_embeds,
123
+ **kwargs)
124
+ else:
125
+ base_output = self.transformer.forward(inputs_embeds=inputs_embeds,
126
+ decoder_inputs_embeds=decoder_inputs_embeds,
127
+ **kwargs)
128
+
129
+ # Get logits
130
+ hidden_states = self.projector(base_output.last_hidden_state)
131
+ pooled_output = hidden_states.mean(dim=1)
132
+
133
+ logits = self.classifier(pooled_output)
134
+ loss = None
135
+
136
+ if labels is not None:
137
+ loss = self.loss_fn(logits.view(-1, self.num_classes), labels.view(-1))
138
+
139
+ return OsuClassifierOutput(
140
+ loss=loss,
141
+ logits=logits,
142
+ encoder_last_hidden_state=base_output.encoder_last_hidden_state,
143
+ decoder_last_hidden_state=base_output.last_hidden_state,
144
+ feature_vector=pooled_output
145
+ )
classifier/libs/model/spectrogram.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from nnAudio import features
6
+
7
+
8
+ class MelSpectrogram(nn.Module):
9
+ def __init__(
10
+ self,
11
+ sample_rate: int = 16000,
12
+ n_ftt: int = 2048,
13
+ n_mels: int = 512,
14
+ hop_length: int = 128,
15
+ ):
16
+ """
17
+ Melspectrogram transformation layer, supports on-the-fly processing on GPU.
18
+
19
+ Attributes:
20
+ sample_rate: The sampling rate for the input audio.
21
+ n_ftt: The window size for the STFT.
22
+ n_mels: The number of Mel filter banks.
23
+ hop_length: The hop (or stride) size.
24
+ """
25
+ super().__init__()
26
+ self.transform = features.MelSpectrogram(
27
+ sr=sample_rate,
28
+ n_fft=n_ftt,
29
+ n_mels=n_mels,
30
+ hop_length=hop_length,
31
+ center=True,
32
+ fmin=0,
33
+ fmax=sample_rate // 2,
34
+ pad_mode="constant",
35
+ )
36
+
37
+ def forward(self, samples: torch.tensor) -> torch.tensor:
38
+ """
39
+ Convert a batch of audio frames into a batch of Mel spectrogram frames.
40
+
41
+ For each item in the batch:
42
+ 1. pad left and right ends of audio by n_fft // 2.
43
+ 2. run STFT with window size of |n_ftt| and stride of |hop_length|.
44
+ 3. convert result into mel-scale.
45
+ 4. therefore, n_frames = n_samples // hop_length + 1.
46
+
47
+ Args:
48
+ samples: Audio time-series (batch size, n_samples).
49
+
50
+ Returns:
51
+ A batch of Mel spectrograms of size (batch size, n_frames, n_mels).
52
+ """
53
+ spectrogram = self.transform(samples)
54
+ spectrogram = spectrogram.permute(0, 2, 1)
55
+ return spectrogram
classifier/libs/tokenizer/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .event import *
2
+ from .tokenizer import Tokenizer
classifier/libs/tokenizer/event.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ from enum import Enum
5
+
6
+
7
class EventType(Enum):
    """Vocabulary of beatmap event kinds; values are the short string
    prefixes used when an Event is rendered as text (see Event.__repr__)."""
    # Timing / rhythm attributes
    TIME_SHIFT = "t"
    SNAPPING = "snap"
    # Spatial attributes
    DISTANCE = "dist"
    NEW_COMBO = "new_combo"
    # Hitsound attributes
    HITSOUND = "hitsound"
    VOLUME = "volume"
    # Hit-object types
    CIRCLE = "circle"
    SPINNER = "spinner"
    SPINNER_END = "spinner_end"
    SLIDER_HEAD = "slider_head"
    # Slider anchor types (curve control points)
    BEZIER_ANCHOR = "bezier_anchor"
    PERFECT_ANCHOR = "perfect_anchor"
    CATMULL_ANCHOR = "catmull_anchor"
    RED_ANCHOR = "red_anchor"
    LAST_ANCHOR = "last_anchor"
    SLIDER_END = "slider_end"
    # Timing grid events
    BEAT = "beat"
    MEASURE = "measure"
    TIMING_POINT = "timing_point"
    # Metadata / conditioning events
    STYLE = "style"
    DIFFICULTY = "difficulty"
    MAPPER = "mapper"
    DESCRIPTOR = "descriptor"
    # Position encodings (split axes or flattened grid index)
    POS_X = "pos_x"
    POS_Y = "pos_y"
    POS = "pos"
    CS = "cs"
35
+
36
+
37
@dataclasses.dataclass
class EventRange:
    """Inclusive integer value range associated with one event type."""
    type: EventType  # the event family this range applies to
    min_value: int  # smallest encodable value (inclusive)
    max_value: int  # largest encodable value (inclusive)
42
+
43
+
44
@dataclasses.dataclass
class Event:
    """A single beatmap event: an event type plus an integer value."""
    type: EventType
    value: int = 0

    def __repr__(self) -> str:
        # Textual form is the type's short prefix followed by the value, e.g. "t11000".
        return f"{self.type.value}{self.value}"

    def __str__(self) -> str:
        return repr(self)
classifier/libs/tokenizer/tokenizer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ from .event import Event, EventType, EventRange
7
+
8
# NOTE(review): "MILISECONDS" is misspelled but kept as-is — other modules may
# import these names, so renaming would be a breaking change.
MILISECONDS_PER_SECOND = 1000
# Time resolution of one TIME_SHIFT token step, in milliseconds.
MILISECONDS_PER_STEP = 10
10
+
11
+
12
class Tokenizer:
    """Fixed-vocabulary tokenizer mapping beatmap :class:`Event` objects to ids.

    The vocabulary is a single [PAD] token at id 0 followed by one contiguous
    span of ids per :class:`EventRange`, in the order of ``event_ranges``.
    """

    __slots__ = [
        "offset",
        "event_ranges",
        "input_event_ranges",
        "num_classes",
        "num_diff_classes",
        "max_difficulty",
        "event_range",
        "event_start",
        "event_end",
        "vocab_size",
        "beatmap_idx",
        "mapper_idx",
        "beatmap_mapper",
        "num_mapper_classes",
        "beatmap_descriptors",
        "descriptor_idx",
        "num_descriptor_classes",
        "num_cs_classes",
    ]

    def __init__(self, args: DictConfig = None):
        """Fixed vocabulary tokenizer.

        Args:
            args: Optional training config. When given, the TIME_SHIFT range is
                derived from the audio window length, optional event ranges
                (distances, positions, timing points) are enabled per the
                config, and the beatmap -> mapper index is loaded from
                ``args.data.mappers_path``.
        """
        self.offset = 1  # id 0 is reserved for the [PAD] token
        self.event_ranges: list[EventRange] = [
            EventRange(EventType.TIME_SHIFT, 0, 1024),
            EventRange(EventType.SNAPPING, 0, 16),
            EventRange(EventType.DISTANCE, 0, 640),
        ]
        # Fix: `decode` iterates `input_event_ranges`, but it was declared in
        # __slots__ without ever being assigned, so decoding an out-of-range
        # token id raised AttributeError instead of the intended ValueError.
        self.input_event_ranges: list[EventRange] = []
        self.num_classes = 0
        self.beatmap_mapper: dict[int, int] = {}  # beatmap_id -> mapper_id
        self.mapper_idx: dict[int, int] = {}  # mapper_id -> mapper_idx

        if args is not None:
            # Duration of one model input window in milliseconds.
            miliseconds_per_sequence = ((args.data.src_seq_len - 1) * args.model.spectrogram.hop_length *
                                        MILISECONDS_PER_SECOND / args.model.spectrogram.sample_rate)
            max_time_shift = int(miliseconds_per_sequence / MILISECONDS_PER_STEP)
            min_time_shift = 0

            self.event_ranges = [
                EventRange(EventType.TIME_SHIFT, min_time_shift, max_time_shift),
                EventRange(EventType.SNAPPING, 0, 16),
            ]

            self._init_mapper_idx(args)

            if args.data.add_distances:
                self.event_ranges.append(EventRange(EventType.DISTANCE, 0, 640))

            if args.data.add_positions:
                p = args.data.position_precision
                x_min, x_max, y_min, y_max = args.data.position_range
                x_min, x_max, y_min, y_max = x_min // p, x_max // p, y_min // p, y_max // p

                if args.data.position_split_axes:
                    # Separate vocabularies for the X and Y axes.
                    self.event_ranges.append(EventRange(EventType.POS_X, x_min, x_max))
                    self.event_ranges.append(EventRange(EventType.POS_Y, y_min, y_max))
                else:
                    # Single flattened (x, y) -> index vocabulary.
                    x_count = x_max - x_min + 1
                    y_count = y_max - y_min + 1
                    self.event_ranges.append(EventRange(EventType.POS, 0, x_count * y_count - 1))

        self.event_ranges: list[EventRange] = self.event_ranges + [
            EventRange(EventType.NEW_COMBO, 0, 0),
            EventRange(EventType.HITSOUND, 0, 2 ** 3 * 3 * 3),
            EventRange(EventType.VOLUME, 0, 100),
            EventRange(EventType.CIRCLE, 0, 0),
            EventRange(EventType.SPINNER, 0, 0),
            EventRange(EventType.SPINNER_END, 0, 0),
            EventRange(EventType.SLIDER_HEAD, 0, 0),
            EventRange(EventType.BEZIER_ANCHOR, 0, 0),
            EventRange(EventType.PERFECT_ANCHOR, 0, 0),
            EventRange(EventType.CATMULL_ANCHOR, 0, 0),
            EventRange(EventType.RED_ANCHOR, 0, 0),
            EventRange(EventType.LAST_ANCHOR, 0, 0),
            EventRange(EventType.SLIDER_END, 0, 0),
            EventRange(EventType.BEAT, 0, 0),
            EventRange(EventType.MEASURE, 0, 0),
        ]

        if args is not None and args.data.add_timing_points:
            self.event_ranges.append(EventRange(EventType.TIMING_POINT, 0, 0))

        self.event_range: dict[EventType, EventRange] = {er.type: er for er in self.event_ranges}

        # Precompute the first (inclusive) and one-past-last token id of every
        # event type so encode/event_type_range are O(1).
        self.event_start: dict[EventType, int] = {}
        self.event_end: dict[EventType, int] = {}
        offset = self.offset
        for er in self.event_ranges:
            self.event_start[er.type] = offset
            offset += er.max_value - er.min_value + 1
            self.event_end[er.type] = offset

        self.vocab_size: int = self.offset + sum(
            er.max_value - er.min_value + 1 for er in self.event_ranges
        )

    @property
    def pad_id(self) -> int:
        """[PAD] token for padding."""
        return 0

    def decode(self, token_id: int) -> Event:
        """Converts token ids into Event objects.

        Raises:
            ValueError: If ``token_id`` is not mapped to any event.
        """
        offset = self.offset
        for er in self.event_ranges:
            if offset <= token_id <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + token_id - offset)
            offset += er.max_value - er.min_value + 1
        for er in self.input_event_ranges:
            if offset <= token_id <= offset + er.max_value - er.min_value:
                return Event(type=er.type, value=er.min_value + token_id - offset)
            offset += er.max_value - er.min_value + 1

        raise ValueError(f"id {token_id} is not mapped to any event")

    def encode(self, event: Event) -> int:
        """Converts Event objects into token ids.

        Raises:
            ValueError: If the event type is unknown or the value lies outside
                the type's declared range.
        """
        if event.type not in self.event_range:
            raise ValueError(f"unknown event type: {event.type}")

        er = self.event_range[event.type]
        offset = self.event_start[event.type]

        if not er.min_value <= event.value <= er.max_value:
            raise ValueError(
                f"event value {event.value} is not within range "
                f"[{er.min_value}, {er.max_value}] for event type {event.type}"
            )

        return offset + event.value - er.min_value

    def event_type_range(self, event_type: EventType) -> tuple[int, int]:
        """Get the token id range of each Event type."""
        if event_type not in self.event_range:
            raise ValueError(f"unknown event type: {event_type}")

        er = self.event_range[event_type]
        offset = self.event_start[event_type]
        return offset, offset + (er.max_value - er.min_value)

    def _init_mapper_idx(self, args):
        """Indexes beatmap mappers and mapper idx."""
        if args is None or "mappers_path" not in args.data:
            raise ValueError("mappers_path not found in args")

        path = Path(args.data.mappers_path)

        if not path.exists():
            raise ValueError(f"mappers_path {path} not found")

        # Load JSON data from file
        with open(path, 'r') as file:
            data = json.load(file)

        # Populate beatmap_mapper
        for item in data:
            self.beatmap_mapper[item['id']] = item['user_id']

        # Get unique user_ids from beatmap_mapper values
        unique_user_ids = list(set(self.beatmap_mapper.values()))

        # Create mapper_idx
        self.mapper_idx = {user_id: idx for idx, user_id in enumerate(unique_user_ids)}
        self.num_classes = len(unique_user_ids)

    def state_dict(self):
        """Serializable snapshot of the tokenizer's vocabulary and indices."""
        return {
            "offset": self.offset,
            "event_ranges": self.event_ranges,
            "input_event_ranges": self.input_event_ranges,
            "num_classes": self.num_classes,
            "event_range": self.event_range,
            "event_start": self.event_start,
            "event_end": self.event_end,
            "vocab_size": self.vocab_size,
            "beatmap_mapper": self.beatmap_mapper,
            "mapper_idx": self.mapper_idx,
        }

    def load_state_dict(self, state_dict):
        """Restore from :meth:`state_dict`.

        Uses ``.get`` for ``input_event_ranges`` so state dicts saved before
        that field existed still load.
        """
        self.offset = state_dict["offset"]
        self.event_ranges = state_dict["event_ranges"]
        self.input_event_ranges = state_dict.get("input_event_ranges", [])
        self.num_classes = state_dict["num_classes"]
        self.event_range = state_dict["event_range"]
        self.event_start = state_dict["event_start"]
        self.event_end = state_dict["event_end"]
        self.vocab_size = state_dict["vocab_size"]
        self.beatmap_mapper = state_dict["beatmap_mapper"]
        self.mapper_idx = state_dict["mapper_idx"]
classifier/libs/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model_utils import *
classifier/libs/utils/model_utils.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import lightning
5
+ import numpy as np
6
+ import torch
7
+ import torchmetrics
8
+ from omegaconf import DictConfig
9
+ from torch.optim import Optimizer, AdamW
10
+ from torch.optim.lr_scheduler import (
11
+ LRScheduler,
12
+ SequentialLR,
13
+ LinearLR,
14
+ CosineAnnealingLR,
15
+ )
16
+ from torch.utils.data import DataLoader
17
+ from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
18
+ from transformers.utils import cached_file
19
+
20
+ import routed_pickle
21
+
22
+ from ..dataset import OrsDataset, OsuParser
23
+ from ..model import OsuClassifier
24
+ from ..model.model import OsuClassifierOutput
25
+ from ..tokenizer import Tokenizer
26
+
27
+
28
class LitOsuClassifier(lightning.LightningModule):
    """Lightning wrapper around :class:`OsuClassifier` for training and eval."""

    def __init__(self, args: DictConfig, tokenizer):
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.model: OsuClassifier = OsuClassifier(args, tokenizer)

    def forward(self, **kwargs) -> OsuClassifierOutput:
        return self.model(**kwargs)

    def training_step(self, batch, batch_idx):
        out: Seq2SeqSequenceClassifierOutput = self.model(**batch)
        self.log("train_loss", out.loss)
        return out.loss

    def testy_step(self, batch, batch_idx, prefix):
        """Shared evaluation logic; logs loss and top-1/10/100 accuracy under `prefix`."""
        out: Seq2SeqSequenceClassifierOutput = self.model(**batch)
        labels = batch["labels"]
        n_cls = self.args.data.num_classes
        top1 = torchmetrics.functional.accuracy(out.logits.argmax(dim=1), labels, "multiclass", num_classes=n_cls)
        top10 = torchmetrics.functional.accuracy(out.logits, labels, "multiclass", num_classes=n_cls, top_k=10)
        top100 = torchmetrics.functional.accuracy(out.logits, labels, "multiclass", num_classes=n_cls, top_k=100)
        self.log(f"{prefix}_loss", out.loss)
        self.log(f"{prefix}_accuracy", top1)
        self.log(f"{prefix}_top_10_accuracy", top10)
        self.log(f"{prefix}_top_100_accuracy", top100)
        return out.loss

    def validation_step(self, batch, batch_idx):
        return self.testy_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.testy_step(batch, batch_idx, "test")

    def configure_optimizers(self):
        optimizer = get_optimizer(self.parameters(), self.args)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": get_scheduler(optimizer, self.args),
                "interval": "step",
                "frequency": 1,
            },
        }
72
+
73
+
74
def load_ckpt(ckpt_path, route_pickle=True):
    """Load a LitOsuClassifier checkpoint from disk or the Hugging Face hub.

    Args:
        ckpt_path: Local checkpoint path, or a hub repo id containing
            ``model.ckpt`` when no such local path exists.
        route_pickle: Use the module-remapping pickle for legacy checkpoints.

    Returns:
        ``(model, model_args, tokenizer)`` with the model in eval mode on
        CUDA when available, otherwise CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve a hub repo id to a cached file when no local path exists.
    if os.path.exists(ckpt_path):
        ckpt_path = Path(ckpt_path)
    else:
        ckpt_path = cached_file(ckpt_path, "model.ckpt")

    checkpoint = torch.load(
        ckpt_path,
        map_location=lambda storage, loc: storage,
        weights_only=False,
        pickle_module=routed_pickle if route_pickle else None,
    )
    tokenizer = checkpoint["hyper_parameters"]["tokenizer"]
    model_args = checkpoint["hyper_parameters"]["args"]

    # Strip the `_orig_mod.` infix that torch.compile inserts into state keys.
    compiled_prefix = "model._orig_mod."
    plain_state = {
        ("model." + k[len(compiled_prefix):] if k.startswith(compiled_prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }

    model = LitOsuClassifier(model_args, tokenizer)
    model.load_state_dict(plain_state)
    model.eval().to(device)
    return model, model_args, tokenizer
102
+
103
+
104
def get_tokenizer(args: DictConfig) -> Tokenizer:
    """Build a :class:`Tokenizer` from the training configuration."""
    return Tokenizer(args)
106
+
107
+
108
def get_optimizer(parameters, args: DictConfig) -> Optimizer:
    """Create the optimizer named by ``args.optim.name``.

    Raises:
        NotImplementedError: For any optimizer name other than ``'adamw'``.
    """
    if args.optim.name != 'adamw':
        raise NotImplementedError

    return AdamW(parameters, lr=args.optim.base_lr)
118
+
119
+
120
def get_scheduler(optimizer: Optimizer, args: DictConfig, num_processes=1) -> LRScheduler:
    """Linear warmup followed by cosine annealing, switched per optimizer step.

    All step counts are scaled by ``num_processes`` so the schedule keeps the
    same wall-clock shape under multi-process training.
    """
    warmup_steps = args.optim.warmup_steps * num_processes
    total_steps = args.optim.total_steps * num_processes

    # Phase 1: ramp the LR from 50% to 100% of the base rate.
    warmup_phase = LinearLR(
        optimizer,
        start_factor=0.5,
        end_factor=1,
        total_iters=warmup_steps,
        last_epoch=-1,
    )

    # Phase 2: cosine decay over the remaining steps down to `final_cosine`.
    cosine_phase = CosineAnnealingLR(
        optimizer,
        T_max=total_steps - warmup_steps,
        eta_min=args.optim.final_cosine,
    )

    return SequentialLR(
        optimizer,
        schedulers=[warmup_phase, cosine_phase],
        milestones=[warmup_steps],
    )
142
+
143
+
144
def get_dataloaders(tokenizer: Tokenizer, args: DictConfig) -> tuple[DataLoader, DataLoader]:
    """Construct the train and test dataloaders over :class:`OrsDataset`.

    Returns:
        ``(train_dataloader, test_dataloader)``.
    """
    parser = OsuParser(args, tokenizer)
    datasets = {
        "train": OrsDataset(args.data, parser, tokenizer),
        "test": OrsDataset(args.data, parser, tokenizer, test=True),
    }

    loaders = {}
    for split in ("train", "test"):
        loaders[split] = DataLoader(
            datasets[split],
            # Per-step batch size; gradient accumulation restores the
            # effective optimizer batch size.
            batch_size=args.optim.batch_size // args.optim.grad_acc,
            num_workers=args.dataloader.num_workers,
            pin_memory=True,
            drop_last=False,
            persistent_workers=args.dataloader.num_workers > 0,
            worker_init_fn=worker_init_fn,
        )

    return loaders["train"], loaders["test"]
175
+
176
+
177
def worker_init_fn(worker_id: int) -> None:
    """
    Give each dataloader a unique slice of the full dataset.
    """
    info = torch.utils.data.get_worker_info()
    ds = info.dataset  # this worker's private copy of the dataset
    overall_start, overall_end = ds.start, ds.end
    # Split [start, end) evenly across workers, rounding each share up so
    # the whole range is covered; clamp the last worker's slice.
    share = int(np.ceil((overall_end - overall_start) / float(info.num_workers)))
    ds.start = overall_start + worker_id * share
    ds.end = min(ds.start + share, overall_end)
classifier/libs/utils/routed_pickle.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from typing import Dict
3
+
4
+
5
class Unpickler(pickle.Unpickler):
    """Unpickler that transparently remaps legacy module paths.

    Old checkpoints reference classes by their pre-refactor module locations;
    ``load_module_mapping`` redirects those lookups to the current package
    layout. Modules not in the mapping resolve normally.
    """

    load_module_mapping: Dict[str, str] = {
        'osuT5.tokenizer.event': 'osuT5.osuT5.event',
        'libs.tokenizer.event': 'classifier.libs.tokenizer.event',
        'libs.tokenizer.tokenizer': 'classifier.libs.tokenizer.tokenizer',
        'osuT5.event': 'osuT5.osuT5.event',
        'libs.event': 'classifier.libs.tokenizer.event',
        'libs.tokenizer': 'classifier.libs.tokenizer.tokenizer',
    }

    def find_class(self, mod_name, name):
        remapped = self.load_module_mapping.get(mod_name, mod_name)
        return super().find_class(remapped, name)
classifier/test.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hydra
2
+ import lightning
3
+ import torch
4
+ from omegaconf import DictConfig
5
+
6
+ from classifier.libs.utils import load_ckpt
7
+ from libs import (
8
+ get_dataloaders,
9
+ )
10
+
11
+ torch.set_float32_matmul_precision('high')
12
+
13
+
14
@hydra.main(config_path="configs", config_name="train_v1", version_base="1.1")
def main(args: DictConfig):
    """Evaluate a trained osu! classifier checkpoint on the validation split."""
    model, model_args, tokenizer = load_ckpt(args.checkpoint_path, route_pickle=False)

    # Only the validation loader is needed here; the train loader is discarded.
    _, val_dataloader = get_dataloaders(tokenizer, args)

    if args.compile:
        model.model = torch.compile(model.model)

    trainer = lightning.Trainer(
        accelerator=args.device,
        precision=args.precision,
    )
    trainer.test(model, val_dataloader)


if __name__ == "__main__":
    main()
classifier/train.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import hydra
4
+ import lightning
5
+ import torch
6
+ from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
7
+ from lightning.pytorch.loggers import WandbLogger
8
+ from omegaconf import DictConfig
9
+
10
+ from libs import (
11
+ get_tokenizer,
12
+ get_dataloaders,
13
+ )
14
+ from libs.model.model import OsuClassifier
15
+ from libs.utils.model_utils import LitOsuClassifier
16
+ torch.set_float32_matmul_precision('high')
17
+
18
+
19
def load_old_model(path: str, model: OsuClassifier):
    """Initialize ``model`` from a legacy ``pytorch_model.bin`` checkpoint.

    Vocabulary/loss-head weights that no longer match are dropped, and the
    old ``transformer.model.`` key prefix is collapsed to ``transformer.``
    to fit the current module layout. Loading is non-strict, so any keys
    missing from the checkpoint are left at their fresh initialization.
    """
    model_state = torch.load(Path(path) / "pytorch_model.bin", weights_only=True)

    ignore_list = [
        "transformer.model.decoder.embed_tokens.weight",
        "transformer.model.decoder.embed_positions.weight",
        "decoder_embedder.weight",
        "transformer.proj_out.weight",
        "loss_fn.weight",
    ]

    old_prefix = "transformer.model."
    fixed_model_state = {
        ("transformer." + k[len(old_prefix):] if k.startswith(old_prefix) else k): v
        for k, v in model_state.items()
        if k not in ignore_list
    }

    model.load_state_dict(fixed_model_state, strict=False)
41
+
42
+
43
@hydra.main(config_path="configs", config_name="train_v1", version_base="1.1")
def main(args: DictConfig):
    """Train the osu! classifier with W&B logging and periodic checkpointing."""
    wandb_logger = WandbLogger(
        project="osu-classifier",
        entity="mappingtools",
        job_type="training",
        offline=args.logging.mode == "offline",
        log_model="all" if args.logging.mode == "online" else False,
    )

    tokenizer = get_tokenizer(args)
    train_dataloader, val_dataloader = get_dataloaders(tokenizer, args)

    model = LitOsuClassifier(args, tokenizer)

    if args.pretrained_path:
        # Warm-start from a legacy checkpoint before (optionally) compiling.
        load_old_model(args.pretrained_path, model.model)

    if args.compile:
        model.model = torch.compile(model.model)

    callbacks = [
        ModelCheckpoint(every_n_train_steps=args.checkpoint.every_steps, save_top_k=2, monitor="val_loss"),
        LearningRateMonitor(logging_interval="step"),
    ]
    trainer = lightning.Trainer(
        accelerator=args.device,
        precision=args.precision,
        logger=wandb_logger,
        max_steps=args.optim.total_steps,
        accumulate_grad_batches=args.optim.grad_acc,
        gradient_clip_val=args.optim.grad_clip,
        val_check_interval=args.eval.every_steps,
        log_every_n_steps=args.logging.every_steps,
        callbacks=callbacks,
    )
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("final.ckpt")


if __name__ == "__main__":
    main()
cli_inference.sh ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Mapperatorinator CLI - Interactive Inference Script
# Based on web-ui.py functionality

set -e # Exit on error

# Colors for better UI (ANSI escape sequences; NC resets to the default color)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
PURPLE='\033[0;35m'
CYAN='\033[0;36m'
NC='\033[0m' # No Color
16
+
17
# Echo $2 wrapped in ANSI color $1, resetting to the default color afterwards.
print_color() {
    echo -e "${1}${2}${NC}"
}
23
+
24
# Print $1 as a boxed section header with surrounding blank lines.
print_header() {
    local border="======================================"
    echo
    print_color $CYAN "$border"
    print_color $CYAN "$1"
    print_color $CYAN "$border"
    echo
}
32
+
33
# Prompt for a line of input with an optional default; store the result in the
# variable named by $3. Uses `printf -v` instead of `eval "$var='$input'"` so
# values containing single quotes no longer break (or inject into) the
# assignment; also keeps `input` local instead of leaking a global.
prompt_input() {
    local prompt=$1
    local default=$2
    local var_name=$3
    local input

    if [ -n "$default" ]; then
        read -e -p "$(print_color $GREEN "$prompt") [default: $default]: " input
        input="${input:-$default}"
    else
        read -e -p "$(print_color $GREEN "$prompt"): " input
    fi

    printf -v "$var_name" '%s' "$input"
}
50
+
51
# Ask a yes/no question and store the literal string "true" or "false" in the
# variable named by $3. $2 ("y" or "n") is the answer assumed on plain Enter.
prompt_yn() {
    local prompt=$1
    local default=$2
    local var_name=$3
    local hint="[y/N]: "
    local fallback="n"

    if [ "$default" = "y" ]; then
        hint="[Y/n]: "
        fallback="y"
    fi

    while true; do
        read -p "$(print_color $GREEN "$prompt") $hint" yn
        yn=${yn:-$fallback}

        case $yn in
            [Yy]* ) eval "$var_name=true"; return;;
            [Nn]* ) eval "$var_name=false"; return;;
            * ) echo "Please answer yes or no.";;
        esac
    done
}
73
+
74
# Present a numbered menu and store the chosen option's text in the variable
# named by $2. Re-prompts until a valid 1-based index is entered.
prompt_choice() {
    local prompt=$1
    local var_name=$2
    shift 2
    local options=("$@")
    local count=${#options[@]}
    local choice i

    while true; do
        print_color $GREEN "$prompt"
        for i in "${!options[@]}"; do
            echo "  $((i+1))) ${options[i]}"
        done
        read -p "Select option (1-${count}): " choice

        if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "$count" ]; then
            eval "$var_name='${options[$((choice-1))]}'"
            return
        fi
        print_color $RED "Invalid choice. Please select 1-${count}."
    done
}
96
+
97
# Function to prompt for multiple selection using arrow keys and spacebar.
# Renders an interactive checkbox menu on the terminal; the result stored in
# the variable named by $2 is a Hydra-style list string '["a","b"]', or the
# empty string when nothing is selected.
# NOTE(review): the EXIT trap body uses `return`, which in a trap context is
# questionable bash (EXIT traps fire at script exit, not function return) —
# confirm the intended cleanup behavior.
prompt_multiselect() {
    local prompt=$1
    local var_name=$2
    shift 2
    local options=("$@")
    local num_options=${#options[@]}
    local selections=()
    for (( i=0; i<num_options; i++ )); do
        selections[i]=0
    done
    local current_idx=0

    # Hide cursor for a cleaner UI
    tput civis 2>/dev/null || true
    # Ensure cursor is shown again on exit
    trap 'tput cnorm; return' EXIT

    # Initial draw
    tput clear

    while true; do
        # Move cursor to top left
        tput cup 0 0

        echo -e "${GREEN}${prompt}${NC}"
        echo "(Use UP/DOWN to navigate, SPACE to select/deselect, ENTER to confirm)"

        for i in "${!options[@]}"; do
            local checkbox="[ ]"
            if [[ ${selections[i]} -eq 1 ]]; then
                checkbox="[${GREEN}x${NC}]"
            fi

            if [ "$i" -eq "$current_idx" ]; then
                echo -e "  ${CYAN}> $checkbox ${options[i]}${NC}"
            else
                echo -e "    $checkbox ${options[i]}"
            fi
        done
        # Clear rest of the screen
        tput ed

        # Read a single keystroke.
        # IFS= ensures space is read as a character, not a delimiter.
        IFS= read -rsn1 key

        # Handle escape sequences for arrow keys
        if [[ "$key" == $'\e' ]]; then
            read -rsn2 -t 0.1 key
        fi

        case "$key" in
            '[A') # Up arrow
                current_idx=$(( (current_idx - 1 + num_options) % num_options ))
                ;;
            '[B') # Down arrow
                current_idx=$(( (current_idx + 1) % num_options ))
                ;;
            ' ') # Space bar
                if [[ ${selections[current_idx]} -eq 1 ]]; then
                    selections[current_idx]=0
                else
                    selections[current_idx]=1
                fi
                ;;
            '') # Enter key
                break
                ;;
        esac
    done

    # Show cursor again and clear the trap
    tput cnorm 2>/dev/null || true
    trap - EXIT

    # Go back to the bottom of the screen
    tput cup $(tput lines) 0
    clear # Clean up the interactive menu from screen

    # Collect selected options
    local selected_options=()
    for i in "${!options[@]}"; do
        if [[ ${selections[i]} -eq 1 ]]; then
            selected_options+=("${options[i]}")
        fi
    done

    # Format the result list for Hydra/Python: '["item1", "item2"]'
    if [ ${#selected_options[@]} -gt 0 ]; then
        local formatted_items=""
        for item in "${selected_options[@]}"; do
            if [ -n "$formatted_items" ]; then
                # Each item is wrapped in double quotes
                formatted_items="$formatted_items,\"$item\""
            else
                formatted_items="\"$item\""
            fi
        done
        # The whole list is wrapped in brackets
        eval "$var_name='[$formatted_items]'"
    else
        # Return an empty string if nothing is selected
        eval "$var_name=''"
    fi
}
203
+
204
+
205
# Succeed silently when $1 names an existing regular file; otherwise print an
# error in red and return non-zero.
validate_file() {
    if [ -f "$1" ]; then
        return 0
    fi
    print_color $RED "File not found: $1"
    return 1
}
214
+
215
# Convert a POSIX path to Windows form when running under Cygwin/MSYS/MinGW;
# echo the path unchanged on other platforms. Empty input yields empty output.
convert_path_if_needed() {
    local input_path="$1"

    if [[ -z "$input_path" ]]; then
        echo ""
        return
    fi

    case "$(uname -s)" in
        CYGWIN*|MINGW*|MSYS*)
            cygpath -w "$input_path"
            ;;
        *)
            echo "$input_path"
            ;;
    esac
}
236
+
237
# Main script starts here: interactive flow begins below.
print_color $PURPLE "╔═══════════════════════════════════════════╗"
print_color $PURPLE "║ Mapperatorinator CLI ║"
print_color $PURPLE "║ Interactive Inference Setup ║"
print_color $PURPLE "╚═══════════════════════════════════════════╝"
echo
243
+
244
# 2. Required Paths
print_header "Required Paths"

# Python Path
prompt_input "Python executable path" "python" python_executable

# Audio Path (Required) — loop until a non-empty, existing file is given
while true; do
    prompt_input "Audio file path (required)" "input/demo.mp3" audio_path
    if [ -z "$audio_path" ]; then
        print_color $RED "Audio path is required!"
        continue
    fi
    if validate_file "$audio_path"; then
        break
    fi
done

# Output Path — defaults to the audio file's directory
prompt_input "Output directory path" "$(dirname "$audio_path")" output_path

# Beatmap Path (Optional)
prompt_input "Beatmap file path (optional, for in-context learning)" "" beatmap_path
if [ -n "$beatmap_path" ] && ! validate_file "$beatmap_path"; then
    print_color $YELLOW "Warning: Beatmap file not found, continuing without it"
    beatmap_path=""
fi

# Convert paths to Windows format if needed (for Cygwin/MinGW)
audio_path=$(convert_path_if_needed "$audio_path")
output_path=$(convert_path_if_needed "$output_path")
beatmap_path=$(convert_path_if_needed "$beatmap_path")
276
+
277
# 3. Basic Settings
print_header "Basic Settings"

# Model Selection — entries are "config_name:description"
model_options=(
    "v28:Mapperatorinator V28"
    "v29:Mapperatorinator V29 (Supports gamemodes and descriptors)"
    "v30:Mapperatorinator V30 (Best stable model)"
    "v31:Mapperatorinator V31 (Slightly more accurate than V29)"
    "beatheritage_v1:BeatHeritage V1 (Enhanced stability & quality)"
)

print_color $GREEN "Select Model:"
for i in "${!model_options[@]}"; do
    IFS=':' read -r value desc <<< "${model_options[i]}"
    echo "  $((i+1))) $desc"
done

while true; do
    read -p "Select model (1-${#model_options[@]}) [default: 5 - BeatHeritage V1]: " model_choice
    model_choice=${model_choice:-5}

    if [[ "$model_choice" =~ ^[1-5]$ ]]; then
        # Split the selected entry back into config name and description.
        IFS=':' read -r model_config model_desc <<< "${model_options[$((model_choice-1))]}"
        print_color $BLUE "Selected: $model_desc"
        break
    else
        print_color $RED "Invalid choice. Please select 1-${#model_options[@]}."
    fi
done

# Game Mode (MODIFIED BLOCK) — selected by 0-based numeric index
gamemode_options=("osu!" "Taiko" "Catch" "Mania")
while true; do
    print_color $GREEN "Game mode:"
    for i in "${!gamemode_options[@]}"; do
        echo "  $i) ${gamemode_options[$i]}"
    done
    read -p "$(print_color $GREEN "Select option (0-3)") [default: 0]: " gamemode_input
    # Set default value to 0 if input is empty
    gamemode=${gamemode_input:-0}

    if [[ "$gamemode" =~ ^[0-3]$ ]]; then
        break
    else
        print_color $RED "Invalid choice. Please select a number between 0 and 3."
        echo # Add a blank line for spacing before re-prompting
    fi
done

# Difficulty
prompt_input "Difficulty (1.0-10.0)" "5.5" difficulty

# Year
# default is 2023, and 2007-2023 are valid years
prompt_input "Year" "2023" year
if ! [[ "$year" =~ ^(200[7-9]|201[0-9]|202[0-3])$ ]]; then
    print_color $RED "Invalid year! Year must be between 2007 and 2023. Defaulting to 2023."
    year=2023
fi
337
+
338
# 4. Advanced Settings (Optional) — blank answers are skipped by add_arg later
print_header "Advanced Settings (Optional - Press Enter to skip)"
print_color $BLUE "Difficulty Settings:"
prompt_input "HP Drain Rate (0-10)" "" hp_drain_rate
prompt_input "Circle Size (0-10)" "" circle_size
prompt_input "Overall Difficulty (0-10)" "" overall_difficulty
prompt_input "Approach Rate (0-10)" "" approach_rate
print_color $BLUE "Slider Settings:"
prompt_input "Slider Multiplier" "" slider_multiplier
prompt_input "Slider Tick Rate" "" slider_tick_rate
# Mania-only settings (gamemode index 3)
if [ "$gamemode" -eq 3 ]; then
    print_color $BLUE "Mania Settings:"
    prompt_input "Key Count" "" keycount
    prompt_input "Hold Note Ratio (0-1)" "" hold_note_ratio
    prompt_input "Scroll Speed Ratio" "" scroll_speed_ratio
fi
print_color $BLUE "Generation Settings:"
prompt_input "CFG Scale (1-20)" "" cfg_scale
prompt_input "Temperature (0-2)" "" temperature
prompt_input "Top P (0-1)" "" top_p
prompt_input "Seed (random if empty)" "" seed
prompt_input "Mapper ID" "" mapper_id
print_color $BLUE "Timing Settings:"
prompt_input "Start Time (seconds)" "" start_time
prompt_input "End Time (seconds)" "" end_time
363
+
364
# 5. Boolean Options
print_header "Export & Processing Options"
prompt_yn "Export as .osz file?" "n" export_osz
prompt_yn "Add to existing beatmap?" "n" add_to_beatmap
prompt_yn "Add hitsounds?" "n" hitsounded
prompt_yn "Use super timing analysis?" "n" super_timing

# 6. Descriptors
print_header "Style Descriptors"

# Positive descriptors with interactive multi-select
descriptor_options=("jump aim" "stream" "tech" "aim" "speed" "flow" "clean" "complex" "simple" "modern" "classic" "spaced" "stacked")
prompt_multiselect "Positive descriptors (describe desired mapping style):" descriptors "${descriptor_options[@]}"

# Negative descriptors with interactive multi-select
prompt_multiselect "Negative descriptors (styles to avoid):" negative_descriptors "${descriptor_options[@]}"

# In-context options (only if beatmap is provided)
if [ -n "$beatmap_path" ]; then
    print_header "In-Context Learning Options"
    context_options_list=("timing" "patterns" "structure" "style")
    prompt_multiselect "In-context learning aspects:" in_context_options "${context_options_list[@]}"
fi
387
+
388
+
389
# 7. Build and Execute Command
print_header "Command Generation"

# Start building the command: the hydra config name (-cn) selects the model.
cmd_args=("$python_executable" "inference.py" "-cn" "$model_config")
394
+
395
# Append "key=value" to cmd_args, skipping entirely when the value is empty.
# The key=value form is what Hydra expects, even for list-valued strings
# such as descriptors='["item1", "item2"]'.
add_arg() {
    if [ -n "$2" ]; then
        cmd_args+=("${1}=${2}")
    fi
}
405
+
406
# Append "key=true" when $2 is the string "true", otherwise "key=false".
add_bool_arg() {
    local flag="false"
    if [ "$2" = "true" ]; then
        flag="true"
    fi
    cmd_args+=("${1}=${flag}")
}
416
+
417
# Add all arguments (empty values are dropped by add_arg)
add_arg "audio_path" "'$audio_path'"
add_arg "output_path" "'$output_path'"
add_arg "beatmap_path" "'$beatmap_path'"
add_arg "gamemode" "$gamemode"
add_arg "difficulty" "$difficulty"
add_arg "year" "$year"

# Optional numeric parameters
add_arg "hp_drain_rate" "$hp_drain_rate"
add_arg "circle_size" "$circle_size"
add_arg "overall_difficulty" "$overall_difficulty"
add_arg "approach_rate" "$approach_rate"
add_arg "slider_multiplier" "$slider_multiplier"
add_arg "slider_tick_rate" "$slider_tick_rate"
add_arg "keycount" "$keycount"
add_arg "hold_note_ratio" "$hold_note_ratio"
add_arg "scroll_speed_ratio" "$scroll_speed_ratio"
add_arg "cfg_scale" "$cfg_scale"
add_arg "temperature" "$temperature"
add_arg "top_p" "$top_p"
add_arg "seed" "$seed"
add_arg "mapper_id" "$mapper_id"
add_arg "start_time" "$start_time"
add_arg "end_time" "$end_time"

# List parameters (now correctly quoted)
add_arg "descriptors" "$descriptors"
add_arg "negative_descriptors" "$negative_descriptors"
add_arg "in_context" "$in_context_options"

# Boolean parameters
add_bool_arg "export_osz" "$export_osz"
add_bool_arg "add_to_beatmap" "$add_to_beatmap"
add_bool_arg "hitsounded" "$hitsounded"
add_bool_arg "super_timing" "$super_timing"


# Display the command
print_color $YELLOW "Generated command:"
echo
# Use printf for safer printing of arguments
printf "%s " "${cmd_args[@]}"
echo
echo
462
+
463
# Ask for confirmation before running anything.
prompt_yn "Execute this command?" "y" execute_cmd

if [ "$execute_cmd" = "true" ]; then
    print_header "Executing Inference"
    print_color $GREEN "Starting inference process..."
    echo

    # Run the command as an `if` condition so a non-zero exit does not kill
    # the whole script under `set -e`. Previously the failure branch below
    # was unreachable: `set -e` aborted before `exit_code=$?` ever ran.
    if "${cmd_args[@]}"; then
        exit_code=0
    else
        exit_code=$?
    fi

    echo
    if [ $exit_code -eq 0 ]; then
        print_color $GREEN "✓ Inference completed successfully!"
    else
        print_color $RED "✗ Inference failed with exit code: $exit_code"
    fi
else
    print_color $YELLOW "Command generation cancelled."
    echo
    print_color $BLUE "You can copy and run the command manually:"
    # Use printf for safer printing of arguments
    printf "%s " "${cmd_args[@]}"
    echo
fi

echo
print_color $PURPLE "Thank you for using Mapperatorinator CLI!"
colab/beatheritage_v1_inference.ipynb ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "<a href=\"https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/beatheritage_v1_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
8
+ "\n",
9
+ "# BeatHeritage V1 - Beatmap Generator\n",
10
+ "\n",
11
+ "An enhanced AI model for generating osu! beatmaps with improved stability and quality control.\n",
12
+ "\n",
13
+ "\n",
14
+ "### Instructions:\n",
15
+ "1. **Read and accept the rules** by clicking the checkbox in the first cell\n",
16
+ "2. **Ensure GPU runtime**: Go to __Runtime → Change Runtime Type → GPU__\n",
17
+ "3. **Execute cells in order**: Click ▶️ on each cell sequentially\n",
18
+ "4. **Upload your audio**: Choose an MP3/OGG file when prompted\n",
19
+ "5. **Configure parameters**: Adjust settings to your preference\n",
20
+ "6. **Generate beatmap**: Run the generation cell and wait for results\n"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "metadata": {},
26
+ "source": [
27
+ "#@title 🚀 Setup Environment { display-mode: \"form\" }\n",
28
+ "#@markdown ### ⚠️ Important: Please use this tool responsibly\n",
29
+ "#@markdown - Always disclose AI usage in your beatmap descriptions\n",
30
+ "#@markdown - Respect the original music artists and mappers\n",
31
+ "#@markdown - This tool is for educational and creative purposes\n",
32
+ "\n",
33
+ "i_accept_the_rules = False #@param {type:\"boolean\"}\n",
34
+ "#@markdown ☑️ **I accept the rules and will use this tool responsibly**\n",
35
+ "\n",
36
+ "import os\n",
37
+ "import sys\n",
38
+ "\n",
39
+ "if not i_accept_the_rules:\n",
40
+ " raise ValueError(\"Please read and accept the rules before proceeding!\")\n",
41
+ "\n",
42
+ "print(\"Installing BeatHeritage...\")\n",
43
+ "print(\"=\"*50)\n",
44
+ "\n",
45
+ "# Clone repository if not exists\n",
46
+ "if not os.path.exists('/content/BeatHeritage'):\n",
47
+ " !git clone -q https://github.com/hongminh54/BeatHeritage.git\n",
48
+ " print(\"✅ Repository cloned\")\n",
49
+ "else:\n",
50
+ " print(\"✅ Repository already exists\")\n",
51
+ "\n",
52
+ "%cd /content/BeatHeritage\n",
53
+ "\n",
54
+ "# Install dependencies\n",
55
+ "print(\"\\nInstalling dependencies...\")\n",
56
+ "!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
57
+ "!pip install -q -r requirements.txt\n",
58
+ "!apt-get install -y ffmpeg > /dev/null 2>&1\n",
59
+ "\n",
60
+ "print(\"\\nSetup complete!\")\n",
61
+ "\n",
62
+ "# Import required libraries\n",
63
+ "import warnings\n",
64
+ "warnings.filterwarnings('ignore')\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "from google.colab import files\n",
68
+ "from IPython.display import display, HTML, Audio\n",
69
+ "from pathlib import Path\n",
70
+ "import json\n",
71
+ "import shlex\n",
72
+ "import subprocess\n",
73
+ "from datetime import datetime\n",
74
+ "import zipfile\n",
75
+ "\n",
76
+ "# Check GPU availability\n",
77
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
78
+ "print(f\"\\nUsing device: {device}\")\n",
79
+ "if device == 'cuda':\n",
80
+ " gpu_name = torch.cuda.get_device_name(0)\n",
81
+ " gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3\n",
82
+ " print(f\"GPU: {gpu_name}\")\n",
83
+ " print(f\"Memory: {gpu_memory:.1f} GB\")\n",
84
+ "else:\n",
85
+ " print(\"No GPU detected! Generation will be VERY slow.\")\n",
86
+ "\n",
87
+ "# Initialize global variables\n",
88
+ "audio_path = \"\"\n",
89
+ "output_path = \"/content/BeatHeritage/output\"\n",
90
+ "os.makedirs(output_path, exist_ok=True)"
91
+ ],
92
+ "outputs": [],
93
+ "execution_count": null
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "metadata": {},
98
+ "source": [
99
+ "#@title 🎵 Upload Audio File { display-mode: \"form\" }\n",
100
+ "#@markdown Upload your audio file (MP3, OGG, or WAV format)\n",
101
+ "\n",
102
+ "def upload_and_validate_audio():\n",
103
+ " \"\"\"Upload and validate audio file with proper error handling\"\"\"\n",
104
+ " global audio_path\n",
105
+ " \n",
106
+ " print(\"Please select an audio file to upload...\")\n",
107
+ " uploaded = files.upload()\n",
108
+ " \n",
109
+ " if not uploaded:\n",
110
+ " print(\"No file uploaded\")\n",
111
+ " return None\n",
112
+ " \n",
113
+ " # Get the first uploaded file\n",
114
+ " original_filename = list(uploaded.keys())[0]\n",
115
+ " \n",
116
+ " # Clean filename - remove special characters and spaces\n",
117
+ " import re\n",
118
+ " clean_filename = re.sub(r'[^a-zA-Z0-9._-]', '_', original_filename)\n",
119
+ " clean_filename = clean_filename.replace(' ', '_')\n",
120
+ " \n",
121
+ " # Ensure proper extension\n",
122
+ " if not any(clean_filename.lower().endswith(ext) for ext in ['.mp3', '.ogg', '.wav']):\n",
123
+ " print(f\"Invalid file format: {original_filename}\")\n",
124
+ " print(\"Please upload an MP3, OGG, or WAV file\")\n",
125
+ " return None\n",
126
+ " \n",
127
+ " # Save with cleaned filename\n",
128
+ " audio_path = f'/content/BeatHeritage/{clean_filename}'\n",
129
+ " \n",
130
+ " # Write the uploaded content to the new path\n",
131
+ " with open(audio_path, 'wb') as f:\n",
132
+ " f.write(uploaded[original_filename])\n",
133
+ " \n",
134
+ " print(f\"Audio uploaded successfully!\")\n",
135
+ " print(f\"Original: {original_filename}\")\n",
136
+ " print(f\"Saved as: {clean_filename}\")\n",
137
+ " print(f\"Path: {audio_path}\")\n",
138
+ " \n",
139
+ " # Display audio player\n",
140
+ " display(Audio(audio_path))\n",
141
+ " \n",
142
+ " return audio_path\n",
143
+ "\n",
144
+ "# Upload audio\n",
145
+ "audio_path = upload_and_validate_audio()\n",
146
+ "\n",
147
+ "if not audio_path:\n",
148
+ " print(\"\\n⚠Please run this cell again and upload a valid audio file\")"
149
+ ],
150
+ "outputs": [],
151
+ "execution_count": null
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "metadata": {},
156
+ "source": [
157
+ "#@title ⚙️ Configure Generation Parameters { display-mode: \"form\" }\n",
158
+ "\n",
159
+ "#@markdown ### 🎯 Basic Settings\n",
160
+ "#@markdown ---\n",
161
+ "#@markdown Choose the AI model version to use:\n",
162
+ "model_version = \"BeatHeritage V1 (Enhanced)\" #@param [\"BeatHeritage V1 (Enhanced)\", \"Mapperatorinator V30\", \"Mapperatorinator V29\", \"Mapperatorinator V28\"]\n",
163
+ "\n",
164
+ "#@markdown Select the game mode for your beatmap:\n",
165
+ "gamemode = \"Standard\" #@param [\"Standard\", \"Taiko\", \"Catch the Beat\", \"Mania\"]\n",
166
+ "\n",
167
+ "#@markdown Target difficulty (★ rating):\n",
168
+ "difficulty = 5.5 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
169
+ "\n",
170
+ "#@markdown ### 🎨 Style Configuration\n",
171
+ "#@markdown ---\n",
172
+ "#@markdown Primary mapping style descriptor:\n",
173
+ "descriptor_1 = \"clean\" #@param [\"clean\", \"tech\", \"jump aim\", \"stream\", \"aim\", \"speed\", \"flow\", \"complex\", \"simple\", \"modern\", \"classic\", \"slider tech\", \"alt\", \"precision\", \"stamina\"]\n",
174
+ "\n",
175
+ "#@markdown Secondary style descriptor (optional):\n",
176
+ "descriptor_2 = \"\" #@param [\"\", \"clean\", \"tech\", \"jump aim\", \"stream\", \"aim\", \"speed\", \"flow\", \"complex\", \"simple\", \"modern\", \"classic\", \"slider tech\", \"alt\", \"precision\", \"stamina\"]\n",
177
+ "\n",
178
+ "#@markdown ### 🔧 Advanced Parameters\n",
179
+ "#@markdown ---\n",
180
+ "#@markdown Generation temperature (lower = more conservative):\n",
181
+ "temperature = 0.85 #@param {type:\"slider\", min:0.1, max:2.0, step:0.05}\n",
182
+ "\n",
183
+ "#@markdown Top-p sampling (nucleus sampling):\n",
184
+ "top_p = 0.92 #@param {type:\"slider\", min:0.1, max:1.0, step:0.01}\n",
185
+ "\n",
186
+ "#@markdown Classifier-free guidance scale:\n",
187
+ "cfg_scale = 7.5 #@param {type:\"slider\", min:1.0, max:20.0, step:0.5}\n",
188
+ "\n",
189
+ "#@markdown ### 📊 Quality Control (BeatHeritage V1)\n",
190
+ "#@markdown ---\n",
191
+ "enable_auto_correction = True #@param {type:\"boolean\"}\n",
192
+ "enable_flow_optimization = True #@param {type:\"boolean\"}\n",
193
+ "enable_pattern_variety = True #@param {type:\"boolean\"}\n",
194
+ "\n",
195
+ "#@markdown ### 🎯 Export Options\n",
196
+ "#@markdown ---\n",
197
+ "super_timing = False #@param {type:\"boolean\"}\n",
198
+ "#@markdown Enable for songs with variable BPM (slower generation)\n",
199
+ "\n",
200
+ "export_osz = True #@param {type:\"boolean\"}\n",
201
+ "#@markdown Export as .osz package (includes audio)\n",
202
+ "\n",
203
+ "# Map model names to config names\n",
204
+ "model_configs = {\n",
205
+ " \"BeatHeritage V1 (Enhanced)\": \"beatheritage_v1\",\n",
206
+ " \"Mapperatorinator V30\": \"v30\",\n",
207
+ " \"Mapperatorinator V29\": \"v29\",\n",
208
+ " \"Mapperatorinator V28\": \"v28\"\n",
209
+ "}\n",
210
+ "\n",
211
+ "# Map gamemode names to indices\n",
212
+ "gamemode_indices = {\n",
213
+ " \"Standard\": 0,\n",
214
+ " \"Taiko\": 1,\n",
215
+ " \"Catch the Beat\": 2,\n",
216
+ " \"Mania\": 3\n",
217
+ "}\n",
218
+ "\n",
219
+ "selected_model = model_configs[model_version]\n",
220
+ "selected_gamemode = gamemode_indices[gamemode]\n",
221
+ "\n",
222
+ "# Build descriptor list\n",
223
+ "descriptors = [d for d in [descriptor_1, descriptor_2] if d]\n",
224
+ "\n",
225
+ "# Display configuration summary\n",
226
+ "print(\"Configuration Summary\")\n",
227
+ "print(\"=\"*50)\n",
228
+ "print(f\"Model: {model_version}\")\n",
229
+ "print(f\"Game Mode: {gamemode}\")\n",
230
+ "print(f\"Difficulty: {difficulty}★\")\n",
231
+ "print(f\"Style: {', '.join(descriptors) if descriptors else 'Default'}\")\n",
232
+ "print(f\"Temperature: {temperature}\")\n",
233
+ "print(f\"Top-p: {top_p}\")\n",
234
+ "print(f\"CFG Scale: {cfg_scale}\")\n",
235
+ "\n",
236
+ "if selected_model == \"beatheritage_v1\":\n",
237
+ " print(\"\\nBeatHeritage V1 Features:\")\n",
238
+ " if enable_auto_correction:\n",
239
+ " print(\" ✓ Auto-correction enabled\")\n",
240
+ " if enable_flow_optimization:\n",
241
+ " print(\" ✓ Flow optimization enabled\")\n",
242
+ " if enable_pattern_variety:\n",
243
+ " print(\" ✓ Pattern variety enabled\")\n",
244
+ "\n",
245
+ "if super_timing:\n",
246
+ " print(\"\\nSuper timing enabled (for variable BPM)\")\n",
247
+ "\n",
248
+ "print(\"\\nConfiguration ready!\")"
249
+ ],
250
+ "outputs": [],
251
+ "execution_count": null
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "metadata": {},
256
+ "source": [
257
+ "#@title 🎮 Generate Beatmap { display-mode: \"form\" }\n",
258
+ "#@markdown Click the play button to start generation. This may take a few minutes depending on song length.\n",
259
+ "\n",
260
+ "def generate_beatmap():\n",
261
+ " \"\"\"Generate beatmap with proper error handling and progress tracking\"\"\"\n",
262
+ " \n",
263
+ " if not audio_path or not os.path.exists(audio_path):\n",
264
+ " print(\"Error: No audio file found!\")\n",
265
+ " print(\"Please upload an audio file first.\")\n",
266
+ " return None\n",
267
+ " \n",
268
+ " print(\"Starting beatmap generation...\")\n",
269
+ " print(\"=\"*50)\n",
270
+ " print(f\"Audio: {os.path.basename(audio_path)}\")\n",
271
+ " print(f\"Model: {model_version}\")\n",
272
+ " print(f\"Mode: {gamemode}\")\n",
273
+ " print(f\"Difficulty: {difficulty}★\")\n",
274
+ " print(\"=\"*50)\n",
275
+ " print()\n",
276
+ " \n",
277
+ " # Build command with proper escaping\n",
278
+ " cmd = [\n",
279
+ " 'python', 'inference.py',\n",
280
+ " '-cn', selected_model,\n",
281
+ " f'audio_path={shlex.quote(audio_path)}',\n",
282
+ " f'output_path={shlex.quote(output_path)}',\n",
283
+ " f'gamemode={selected_gamemode}',\n",
284
+ " f'difficulty={difficulty}',\n",
285
+ " f'temperature={temperature}',\n",
286
+ " f'top_p={top_p}',\n",
287
+ " f'cfg_scale={cfg_scale}',\n",
288
+ " f'super_timing={str(super_timing).lower()}',\n",
289
+ " f'export_osz={str(export_osz).lower()}',\n",
290
+ " ]\n",
291
+ " \n",
292
+ " # Add descriptors if specified\n",
293
+ " if descriptors:\n",
294
+ " desc_str = json.dumps(descriptors)\n",
295
+ " cmd.append(f'descriptors={shlex.quote(desc_str)}')\n",
296
+ " \n",
297
+ " # Add BeatHeritage V1 specific features\n",
298
+ " if selected_model == \"beatheritage_v1\":\n",
299
+ " if enable_auto_correction:\n",
300
+ " cmd.append('quality_control.enable_auto_correction=true')\n",
301
+ " if enable_flow_optimization:\n",
302
+ " cmd.append('quality_control.enable_flow_optimization=true')\n",
303
+ " if enable_pattern_variety:\n",
304
+ " cmd.append('advanced_features.enable_pattern_variety=true')\n",
305
+ " \n",
306
+ " # Always enable these for V1\n",
307
+ " cmd.extend([\n",
308
+ " 'advanced_features.enable_context_aware_generation=true',\n",
309
+ " 'advanced_features.enable_style_preservation=true',\n",
310
+ " 'generate_positions=true',\n",
311
+ " 'position_refinement=true'\n",
312
+ " ])\n",
313
+ " \n",
314
+ " # Execute command\n",
315
+ " try:\n",
316
+ " print(\"⏳ Generating beatmap... (this may take several minutes)\\n\")\n",
317
+ " \n",
318
+ " # Run the command\n",
319
+ " process = subprocess.Popen(\n",
320
+ " cmd,\n",
321
+ " stdout=subprocess.PIPE,\n",
322
+ " stderr=subprocess.STDOUT,\n",
323
+ " text=True,\n",
324
+ " bufsize=1,\n",
325
+ " universal_newlines=True\n",
326
+ " )\n",
327
+ " \n",
328
+ " # Stream output in real-time\n",
329
+ " for line in process.stdout:\n",
330
+ " print(line, end='')\n",
331
+ " \n",
332
+ " # Wait for completion\n",
333
+ " return_code = process.wait()\n",
334
+ " \n",
335
+ " if return_code == 0:\n",
336
+ " print(\"\\n\" + \"=\"*50)\n",
337
+ " print(\"Beatmap generation complete!\")\n",
338
+ " \n",
339
+ " # List generated files\n",
340
+ " generated_files = list(Path(output_path).glob('*'))\n",
341
+ " if generated_files:\n",
342
+ " print(f\"\\nGenerated {len(generated_files)} file(s):\")\n",
343
+ " for file in generated_files:\n",
344
+ " size_mb = file.stat().st_size / (1024 * 1024)\n",
345
+ " print(f\" • {file.name} ({size_mb:.2f} MB)\")\n",
346
+ " \n",
347
+ " return generated_files\n",
348
+ " else:\n",
349
+ " print(f\"\\nGeneration failed with error code: {return_code}\")\n",
350
+ " return None\n",
351
+ " \n",
352
+ " except Exception as e:\n",
353
+ " print(f\"\\nError during generation: {str(e)}\")\n",
354
+ " print(\"\\nTroubleshooting tips:\")\n",
355
+ " print(\"1. Ensure the audio file is valid\")\n",
356
+ " print(\"2. Check if GPU memory is sufficient\")\n",
357
+ " print(\"3. Try reducing temperature or cfg_scale\")\n",
358
+ " print(\"4. Disable super_timing if enabled\")\n",
359
+ " return None\n",
360
+ "\n",
361
+ "# Run generation\n",
362
+ "generated_files = generate_beatmap()"
363
+ ],
364
+ "outputs": [],
365
+ "execution_count": null
366
+ },
367
+ {
368
+ "cell_type": "code",
369
+ "metadata": {},
370
+ "source": [
371
+ "#@title 📥 Download Generated Files { display-mode: \"form\" }\n",
372
+ "#@markdown Download your generated beatmap files\n",
373
+ "\n",
374
+ "def download_results():\n",
375
+ " \"\"\"Package and download generated beatmap files\"\"\"\n",
376
+ " \n",
377
+ " output_files = list(Path(output_path).glob('*'))\n",
378
+ " \n",
379
+ " if not output_files:\n",
380
+ " print(\"No files to download\")\n",
381
+ " print(\"Please generate a beatmap first.\")\n",
382
+ " return\n",
383
+ " \n",
384
+ " print(\"Preparing files for download...\")\n",
385
+ " print(\"=\"*50)\n",
386
+ " \n",
387
+ " # Create timestamp for unique naming\n",
388
+ " timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\n",
389
+ " \n",
390
+ " # Check if we have .osz files\n",
391
+ " osz_files = [f for f in output_files if f.suffix == '.osz']\n",
392
+ " osu_files = [f for f in output_files if f.suffix == '.osu']\n",
393
+ " \n",
394
+ " # Download .osz files directly if available\n",
395
+ " if osz_files:\n",
396
+ " for osz_file in osz_files:\n",
397
+ " print(f\"\\n📥 Downloading: {osz_file.name}\")\n",
398
+ " files.download(str(osz_file))\n",
399
+ " \n",
400
+ " # Download .osu files\n",
401
+ " elif osu_files:\n",
402
+ " if len(osu_files) == 1:\n",
403
+ " # Single file - download directly\n",
404
+ " osu_file = osu_files[0]\n",
405
+ " print(f\"\\n📥 Downloading: {osu_file.name}\")\n",
406
+ " files.download(str(osu_file))\n",
407
+ " else:\n",
408
+ " # Multiple files - create zip\n",
409
+ " zip_name = f'beatheritage_{gamemode.lower()}_{timestamp}.zip'\n",
410
+ " zip_path = f'/content/{zip_name}'\n",
411
+ " \n",
412
+ " with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:\n",
413
+ " for file in output_files:\n",
414
+ " zipf.write(file, file.name)\n",
415
+ " print(f\" • Added: {file.name}\")\n",
416
+ " \n",
417
+ " print(f\"\\nDownloading: {zip_name}\")\n",
418
+ " files.download(zip_path)\n",
419
+ " \n",
420
+ " # Also handle other files\n",
421
+ " other_files = [f for f in output_files if f.suffix not in ['.osz', '.osu']]\n",
422
+ " if other_files:\n",
423
+ " print(\"\\nAdditional files generated:\")\n",
424
+ " for file in other_files:\n",
425
+ " print(f\" • {file.name}\")\n",
426
+ " \n",
427
+ " print(\"\\nDownload complete!\")\n",
428
+ " print(\"\\nTips:\")\n",
429
+ " print(\"• .osz files can be opened directly in osu!\")\n",
430
+ " print(\"• .osu files should be placed in your Songs folder\")\n",
431
+ " print(\"• Press F5 in osu! to refresh after adding files\")\n",
432
+ "\n",
433
+ "# Download files\n",
434
+ "download_results()"
435
+ ],
436
+ "outputs": [],
437
+ "execution_count": null
438
+ },
439
+ {
440
+ "cell_type": "markdown",
441
+ "metadata": {},
442
+ "source": [
443
+ "---\n",
444
+ "\n",
445
+ "## Additional Information\n",
446
+ "\n",
447
+ "### Tips for Best Results:\n",
448
+ "- **Audio Quality**: Use high-quality audio files (320kbps MP3 or FLAC)\n",
449
+ "- **Difficulty Matching**: Match the difficulty rating to song intensity\n",
450
+ "- **Style Descriptors**: Choose descriptors that match the music genre\n",
451
+ "- **Variable BPM**: Enable `super_timing` for songs with tempo changes\n",
452
+ "\n",
453
+ "### Troubleshooting:\n",
454
+ "\n",
455
+ "**Out of Memory:**\n",
456
+ "- Restart runtime to clear GPU memory\n",
457
+ "- Use shorter songs or segments\n",
458
+ "- Reduce cfg_scale value\n",
459
+ "\n",
460
+ "**Poor Quality Output:**\n",
461
+ "- Lower temperature (0.7-0.8) for stability\n",
462
+ "- Increase cfg_scale (10-15) for stronger guidance\n",
463
+ "- Use more specific descriptors\n",
464
+ "\n",
465
+ "**Generation Errors:**\n",
466
+ "- Ensure audio file has no special characters\n",
467
+ "- Check GPU is enabled in runtime\n",
468
+ "- Try different model versions\n",
469
+ "\n",
470
+ "### Resources:\n",
471
+ "- [GitHub Repository](https://github.com/hongminh54/BeatHeritage)\n",
472
+ "- [Documentation](https://github.com/hongminh54/BeatHeritage/blob/main/README.md)\n",
473
+ "\n",
474
+ "### License & Credits:\n",
475
+ "- BeatHeritage V1 by hongminh54\n",
476
+ "- Based on Mapperatorinator by OliBomby\n",
477
+ "- Please credit AI usage in your beatmap descriptions\n",
478
+ "\n",
479
+ "---"
480
+ ]
481
+ }
482
+ ],
483
+ "metadata": {
484
+ "kernelspec": {
485
+ "display_name": "Python 3",
486
+ "language": "python",
487
+ "name": "python3"
488
+ },
489
+ "language_info": {
490
+ "codemirror_mode": {
491
+ "name": "ipython",
492
+ "version": 3
493
+ },
494
+ "file_extension": ".py",
495
+ "mimetype": "text/x-python",
496
+ "name": "python",
497
+ "nbconvert_exporter": "python",
498
+ "pygments_lexer": "ipython3",
499
+ "version": "3.10.0"
500
+ },
501
+ "colab": {
502
+ "provenance": [],
503
+ "gpuType": "T4",
504
+ "collapsed_sections": []
505
+ },
506
+ "accelerator": "GPU"
507
+ },
508
+ "nbformat": 4,
509
+ "nbformat_minor": 4
510
+ }
colab/classifier_classify.ipynb ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "<a href=\"https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/classifier_classify.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
8
+ "\n",
9
+ "# Beatmap Mapper Classification\n",
10
+ "\n",
11
+ "This notebook is an interactive demo of an osu! beatmap mapper classification model created by OliBomby. This model is capable of predicting which osu! standard ranked mapper mapped any given beatmap by looking at the style. You can use this on your own maps to see which mapper you are most similar to.\n",
12
+ "\n",
13
+ "### Instructions for running:\n",
14
+ "\n",
15
+ "* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
16
+ "* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
17
+ "* __Upload Audio__: choose a .mp3 or .ogg file from your computer.\n",
18
+ "* __Upload Beatmap__: choose a .osu file from your computer.\n",
19
+ "* __Configure__: choose the start time of the segment that the classifier should classify.\n",
20
+ "* Classify the beatmap using the __Classify Beatmap__ cell.\n"
21
+ ],
22
+ "id": "3c19902455e25588"
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "id": "initial_id",
27
+ "metadata": {
28
+ "collapsed": true
29
+ },
30
+ "source": [
31
+ "#@title Setup Environment { display-mode: \"form\" }\n",
32
+ "\n",
33
+ "!git clone https://github.com/OliBomby/Mapperatorinator.git\n",
34
+ "%cd Mapperatorinator\n",
35
+ "\n",
36
+ "!pip install hydra-core lightning nnaudio\n",
37
+ "!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
38
+ "\n",
39
+ "from google.colab import files\n",
40
+ "from hydra import compose, initialize_config_dir\n",
41
+ "from classifier.classify import main\n",
42
+ "\n",
43
+ "input_audio = \"\"\n",
44
+ "input_beatmap = \"\""
45
+ ],
46
+ "outputs": [],
47
+ "execution_count": null
48
+ },
49
+ {
50
+ "metadata": {},
51
+ "cell_type": "code",
52
+ "source": [
53
+ "#@title Upload Audio { display-mode: \"form\" }\n",
54
+ "\n",
55
+ "def upload_audio():\n",
56
+ " data = list(files.upload().keys())\n",
57
+ " if len(data) > 1:\n",
58
+ " print('Multiple files uploaded; using only one.')\n",
59
+ " return data[0]\n",
60
+ "\n",
61
+ "input_audio = upload_audio()"
62
+ ],
63
+ "id": "624a60c5777279e7",
64
+ "outputs": [],
65
+ "execution_count": null
66
+ },
67
+ {
68
+ "metadata": {},
69
+ "cell_type": "code",
70
+ "source": [
71
+ "#@title Upload Beatmap { display-mode: \"form\" }\n",
72
+ "\n",
73
+ "def upload_beatmap():\n",
74
+ " data = list(files.upload().keys())\n",
75
+ " if len(data) > 1:\n",
76
+ " print('Multiple files uploaded; using only one.')\n",
77
+ " return data[0]\n",
78
+ "\n",
79
+ "input_beatmap = upload_beatmap()"
80
+ ],
81
+ "id": "63884394491f6664",
82
+ "outputs": [],
83
+ "execution_count": null
84
+ },
85
+ {
86
+ "metadata": {},
87
+ "cell_type": "code",
88
+ "source": [
89
+ "#@title Configure and Classify Beatmap { display-mode: \"form\" }\n",
90
+ "\n",
91
+ "# @markdown #### Input the start time in seconds of the segment to classify.\n",
92
+ "time = 5 # @param {type:\"number\"}\n",
93
+ " \n",
94
+ "# Create config\n",
95
+ "with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/classifier/configs\"):\n",
96
+ " conf = compose(config_name=\"inference\")\n",
97
+ "\n",
98
+ "# Do inference\n",
99
+ "conf.time = time\n",
100
+ "conf.beatmap_path = input_beatmap\n",
101
+ "conf.audio_path = input_audio\n",
102
+ "conf.mappers_path = \"./datasets/beatmap_users.json\"\n",
103
+ "\n",
104
+ "main(conf)\n"
105
+ ],
106
+ "id": "166eb3e5f9398554",
107
+ "outputs": [],
108
+ "execution_count": null
109
+ }
110
+ ],
111
+ "metadata": {
112
+ "kernelspec": {
113
+ "display_name": "Python 3",
114
+ "language": "python",
115
+ "name": "python3"
116
+ },
117
+ "accelerator": "GPU",
118
+ "language_info": {
119
+ "codemirror_mode": {
120
+ "name": "ipython",
121
+ "version": 2
122
+ },
123
+ "file_extension": ".py",
124
+ "mimetype": "text/x-python",
125
+ "name": "python",
126
+ "nbconvert_exporter": "python",
127
+ "pygments_lexer": "ipython2",
128
+ "version": "2.7.6"
129
+ }
130
+ },
131
+ "nbformat": 4,
132
+ "nbformat_minor": 5
133
+ }
colab/mai_mod_inference.ipynb ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "<a href=\"https://colab.research.google.com/github/OliBomby/Mapperatorinator/blob/main/colab/mai_mod_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
8
+ "\n",
9
+ "# Beatmap Modding with MaiMod\n",
10
+ "\n",
11
+ "This notebook is an interactive demo of an AI-driven osu! Beatmap Modding Tool created by OliBomby. This model is capable of finding various faults and inconsistencies in beatmaps that other automated modding tools cannot detect. Run this tool on your beatmaps to get suggestions on how to improve them.\n",
12
+ "\n",
13
+ "### Instructions for running:\n",
14
+ "\n",
15
+ "* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
16
+ "* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
17
+ "* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
18
+ "* __Upload Audio__: choose the beatmap song .mp3 or .ogg file from your computer. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
19
+ "* __Upload Beatmap__: choose the beatmap .osu file from your computer.\n",
20
+ "* __Generate Suggestions__ to generate suggestions for your uploaded beatmap.\n"
21
+ ],
22
+ "id": "3c19902455e25588"
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "id": "initial_id",
27
+ "metadata": {
28
+ "collapsed": true
29
+ },
30
+ "source": [
31
+ "#@title Setup Environment { display-mode: \"form\" }\n",
32
+ "#@markdown Run this cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
33
+ "\n",
34
+ "!git clone https://github.com/OliBomby/Mapperatorinator.git\n",
35
+ "%cd Mapperatorinator\n",
36
+ "\n",
37
+ "!pip install transformers==4.53.3\n",
38
+ "!pip install hydra-core\n",
39
+ "!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
40
+ "\n",
41
+ "import os\n",
42
+ "from google.colab import files\n",
43
+ "from mai_mod import main\n",
44
+ "from hydra import compose, initialize_config_dir\n",
45
+ "\n",
46
+ "input_audio = \"\"\n",
47
+ "input_beatmap = \"\""
48
+ ],
49
+ "outputs": [],
50
+ "execution_count": null
51
+ },
52
+ {
53
+ "metadata": {},
54
+ "cell_type": "code",
55
+ "source": [
56
+ "#@title Upload Audio { display-mode: \"form\" }\n",
57
+ "#@markdown Run this cell to upload the song of the beatmap that you want to mod. Please upload a .mp3 or .ogg file. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
58
+ "\n",
59
+ "def upload_audio():\n",
60
+ " data = list(files.upload().keys())\n",
61
+ " if len(data) > 1:\n",
62
+ " print('Multiple files uploaded; using only one.')\n",
63
+ " file = data[0]\n",
64
+ " if not file.endswith('.mp3') and not file.endswith('.ogg'):\n",
65
+ " print('Invalid file format. Please upload a .mp3 or .ogg file.')\n",
66
+ " return \"\"\n",
67
+ " return data[0]\n",
68
+ "\n",
69
+ "input_audio = upload_audio()"
70
+ ],
71
+ "id": "624a60c5777279e7",
72
+ "outputs": [],
73
+ "execution_count": null
74
+ },
75
+ {
76
+ "metadata": {},
77
+ "cell_type": "code",
78
+ "source": [
79
+ "#@title Upload Beatmap { display-mode: \"form\" }\n",
80
+ "#@markdown Run this cell to upload the beatmap **.osu** file of the beatmap that you want to mod. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
81
+ "\n",
82
+ "def upload_beatmap():\n",
83
+ " data = list(files.upload().keys())\n",
84
+ " if len(data) > 1:\n",
85
+ " print('Multiple files uploaded; using only one.')\n",
86
+ " file = data[0]\n",
87
+ " if not file.endswith('.osu'):\n",
88
+ " print('Invalid file format. Please upload a .osu file.\\nIn stable you can find the .osu file in the song folder (File > Open Song Folder).\\nIn lazer you can find the .osu file by using File > Edit Externally.')\n",
89
+ " return \"\"\n",
90
+ " return file\n",
91
+ "\n",
92
+ "input_beatmap = upload_beatmap()"
93
+ ],
94
+ "id": "63884394491f6664",
95
+ "outputs": [],
96
+ "execution_count": null
97
+ },
98
+ {
99
+ "metadata": {},
100
+ "cell_type": "code",
101
+ "source": [
102
+ "#@title Generate Suggestions { display-mode: \"form\" }\n",
103
+ "#@markdown Run this cell to generate suggestions for your uploaded beatmap. The suggestions will be printed in the output.\n",
104
+ "\n",
105
+ "# Validate stuff\n",
106
+ "assert os.path.exists(input_beatmap), \"Please upload a beatmap.\"\n",
107
+ "assert os.path.exists(input_audio), \"Please upload an audio file.\"\n",
108
+ " \n",
109
+ "# Create config\n",
110
+ "config = \"mai_mod\"\n",
111
+ "with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/configs\"):\n",
112
+ " conf = compose(config_name=config)\n",
113
+ "\n",
114
+ "# Do inference\n",
115
+ "conf.audio_path = input_audio\n",
116
+ "conf.beatmap_path = input_beatmap\n",
117
+ "conf.precision = \"fp32\" # For some reason AMP causes OOM in Colab\n",
118
+ "\n",
119
+ "main(conf)"
120
+ ],
121
+ "id": "166eb3e5f9398554",
122
+ "outputs": [],
123
+ "execution_count": null
124
+ }
125
+ ],
126
+ "metadata": {
127
+ "kernelspec": {
128
+ "display_name": "Python 3",
129
+ "language": "python",
130
+ "name": "python3"
131
+ },
132
+ "accelerator": "GPU",
133
+ "language_info": {
134
+ "codemirror_mode": {
135
+ "name": "ipython",
136
+ "version": 2
137
+ },
138
+ "file_extension": ".py",
139
+ "mimetype": "text/x-python",
140
+ "name": "python",
141
+ "nbconvert_exporter": "python",
142
+ "pygments_lexer": "ipython2",
143
+ "version": "2.7.6"
144
+ }
145
+ },
146
+ "nbformat": 4,
147
+ "nbformat_minor": 5
148
+ }
colab/mapperatorinator_inference.ipynb ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "<a href=\"https://colab.research.google.com/github/hongminh54/BeatHeritage/blob/main/colab/mapperatorinator_inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
8
+ "\n",
9
+ "# Beatmap Generation with Mapperatorinator\n",
10
+ "\n",
11
+ "This notebook is an interactive demo of an osu! beatmap generation model created by OliBomby. This model is capable of generating hit objects, hitsounds, timing, kiai times, and SVs for all 4 gamemodes. You can upload a beatmap to give to the model as additional context or remap parts of the beatmap.\n",
12
+ "\n",
13
+ "### Instructions for running:\n",
14
+ "\n",
15
+ "* Read and accept the rules regarding using this tool by clicking the checkbox.\n",
16
+ "* Make sure to use a GPU runtime, click: __Runtime >> Change Runtime Type >> GPU__\n",
17
+ "* __Execute each cell in order__. Press ▶️ on the left of each cell to execute the cell.\n",
18
+ "* __Setup Environment__: run the first cell to clone the repository and install the required dependencies. You only need to run this cell once per session.\n",
19
+ "* __Upload Audio__: choose a .mp3 or .ogg file from your computer.\n",
20
+ "* __Upload Beatmap__: optionally choose a beatmap .osu file from your computer. You can find these files in stable by using File > Open Song Folder, or in lazer by using File > Edit Externally.\n",
21
+ "* __Configure__: choose your generation parameters to control the style of the generated beatmap.\n",
22
+ "* Generate the beatmap using the __Generate Beatmap__ cell. (it may take a few minutes depending on the length of the song)\n"
23
+ ],
24
+ "id": "3c19902455e25588"
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "id": "initial_id",
29
+ "metadata": {
30
+ "collapsed": true
31
+ },
32
+ "source": [
33
+ "#@title Setup Environment { display-mode: \"form\" }\n",
34
+ "#@markdown ### Use this tool responsibly. Always disclose the use of AI in your beatmaps. Accept the rules and run this cell.\n",
35
+ "i_accept_the_rules = False # @param {type:\"boolean\"}\n",
36
+ "\n",
37
+ "assert i_accept_the_rules, \"Read and accept the rules first!\"\n",
38
+ "\n",
39
+ "!git clone https://github.com/hongminh54/BeatHeritage.git\n",
40
+ "%cd Mapperatorinator\n",
41
+ "\n",
42
+ "!pip install transformers==4.53.3\n",
43
+ "!pip install hydra-core nnaudio\n",
44
+ "!pip install slider git+https://github.com/OliBomby/slider.git@gedagedigedagedaoh\n",
45
+ "\n",
46
+ "from google.colab import files\n",
47
+ "\n",
48
+ "import os\n",
49
+ "from hydra import compose, initialize_config_dir\n",
50
+ "from osuT5.osuT5.event import ContextType\n",
51
+ "from inference import main\n",
52
+ "\n",
53
+ "output_path = \"output\"\n",
54
+ "input_audio = \"\"\n",
55
+ "input_beatmap = \"\""
56
+ ],
57
+ "outputs": [],
58
+ "execution_count": null
59
+ },
60
+ {
61
+ "metadata": {},
62
+ "cell_type": "code",
63
+ "source": [
64
+ "#@title Upload Audio { display-mode: \"form\" }\n",
65
+ "#@markdown Run this cell to upload audio. This is the song to generate a beatmap for. Please upload a .mp3 or .ogg file.\n",
66
+ "\n",
67
+ "def upload_audio():\n",
68
+ " data = list(files.upload().keys())\n",
69
+ " if len(data) > 1:\n",
70
+ " print('Multiple files uploaded; using only one.')\n",
71
+ " file = data[0]\n",
72
+ " if not file.endswith('.mp3') and not file.endswith('.ogg'):\n",
73
+ " print('Invalid file format. Please upload a .mp3 or .ogg file.')\n",
74
+ " return \"\"\n",
75
+ " return data[0]\n",
76
+ "\n",
77
+ "input_audio = upload_audio()"
78
+ ],
79
+ "id": "624a60c5777279e7",
80
+ "outputs": [],
81
+ "execution_count": null
82
+ },
83
+ {
84
+ "metadata": {},
85
+ "cell_type": "code",
86
+ "source": [
87
+ "#@title (Optional) Upload Beatmap { display-mode: \"form\" }\n",
88
+ "#@markdown This step is required if you want to use `in_context` or `add_to_beatmap` to provide additional info to the model.\n",
89
+ "#@markdown It will also fill in any missing metadata and unknown values in the configuration using info of the reference beatmap.\n",
90
+ "#@markdown Please upload a **.osu** file. You can find the .osu file in the song folder in stable or by using File > Edit Externally in lazer.\n",
91
+ "use_reference_beatmap = False # @param {type:\"boolean\"}\n",
92
+ "\n",
93
+ "def upload_beatmap():\n",
94
+ " data = list(files.upload().keys())\n",
95
+ " if len(data) > 1:\n",
96
+ " print('Multiple files uploaded; using only one.')\n",
97
+ " file = data[0]\n",
98
+ " if not file.endswith('.osu'):\n",
99
+ " print('Invalid file format. Please upload a .osu file.\\nIn stable you can find the .osu file in the song folder (File > Open Song Folder).\\nIn lazer you can find the .osu file by using File > Edit Externally.')\n",
100
+ " return \"\"\n",
101
+ " return file\n",
102
+ "\n",
103
+ "if use_reference_beatmap:\n",
104
+ " input_beatmap = upload_beatmap()\n",
105
+ "else:\n",
106
+ " input_beatmap = \"\""
107
+ ],
108
+ "id": "63884394491f6664",
109
+ "outputs": [],
110
+ "execution_count": null
111
+ },
112
+ {
113
+ "metadata": {},
114
+ "cell_type": "code",
115
+ "source": [
116
+ "#@title Configure and Generate Beatmap { display-mode: \"form\" }\n",
117
+ "\n",
118
+ "#@markdown #### You can input -1 to leave the value unknown.\n",
119
+ "#@markdown ---\n",
120
+ "#@markdown This is the AI model to use. V30 is the most accurate model, but it does not support other gamemodes, year, descriptors, or in_context.\n",
121
+ "model = \"Mapperatorinator V30\" # @param [\"Mapperatorinator V29\", \"Mapperatorinator V30\"]\n",
122
+ "#@markdown This is the game mode to generate a beatmap for.\n",
123
+ "gamemode = \"standard\" # @param [\"standard\", \"taiko\", \"catch the beat\", \"mania\"]\n",
124
+ "#@markdown This is the Star Rating you want your beatmap to be. It might deviate from this number depending on the song intensity and other configuration.\n",
125
+ "difficulty = 5 # @param {type:\"number\"}\n",
126
+ "#@markdown This is the user ID of the ranked mapper to imitate for mapping style. You can find this in the URL of the mapper's profile.\n",
127
+ "mapper_id = -1 # @param {type:\"integer\"}\n",
128
+ "#@markdown This is the year you want the beatmap to be from. It should be in the range of 2007 to 2023.\n",
129
+ "year = 2023 # @param {type:\"integer\"}\n",
130
+ "#@markdown This is whether you want the beatmap to be hitsounded. This works only for mania mode.\n",
131
+ "hitsounded = True # @param {type:\"boolean\"}\n",
132
+ "#@markdown These are the standard difficulty parameters for the beatmap HP, OD, AR, and CS. These are the same as the ones in the editor.\n",
133
+ "hp_drain_rate = 5 # @param {type:\"number\"}\n",
134
+ "circle_size = 4 # @param {type:\"number\"}\n",
135
+ "overall_difficulty = 9 # @param {type:\"number\"}\n",
136
+ "approach_rate = 8 # @param {type:\"number\"}\n",
137
+ "slider_multiplier = 1.4 # @param {type:\"slider\", min:0.4, max:3.6, step:0.1}\n",
138
+ "slider_tick_rate = 1 # @param {type:\"number\"}\n",
139
+ "#@markdown This is the number of keys for the mania beatmap. This works only for mania mode.\n",
140
+ "keycount = 4 # @param {type:\"slider\", min:1, max:18, step:1}\n",
141
+ "#@markdown This is the ratio of hold notes to circles in the beatmap. It should be in the range [0,1]. This works only for mania mode.\n",
142
+ "hold_note_ratio = -1 # @param {type:\"number\"}\n",
143
+ "#@markdown This is the ratio of scroll speed changes to the number of notes. It should be in the range [0,1]. This works only for mania and taiko modes.\n",
144
+ "scroll_speed_ratio = -1 # @param {type:\"number\"}\n",
145
+ "#@markdown These descriptors of the beatmap. Descriptors are used to describe the style of the beatmap. All available descriptors can be found [here](https://osu.ppy.sh/wiki/en/Beatmap/Beatmap_tags).\n",
146
+ "descriptor_1 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
147
+ "descriptor_2 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
148
+ "descriptor_3 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
149
+ "#@markdown These are negative descriptors of the beatmap. Negative descriptors are used to describe what the beatmap should not have. These work only when `cfg_scale` is greater than 1.\n",
150
+ "negative_descriptor_1 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
151
+ "negative_descriptor_2 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
152
+ "negative_descriptor_3 = '' # @param [\"slider only\", \"circle only\", \"collab\", \"megacollab\", \"marathon\", \"gungathon\", \"multi-song\", \"variable timing\", \"accelerating bpm\", \"time signatures\", \"storyboard\", \"storyboard gimmick\", \"keysounds\", \"download unavailable\", \"custom skin\", \"featured artist\", \"custom song\", \"style\", \"messy\", \"geometric\", \"grid snap\", \"hexgrid\", \"freeform\", \"symmetrical\", \"old-style revival\", \"clean\", \"slidershapes\", \"distance snapped\", \"iNiS-style\", \"avant-garde\", \"perfect stacks\", \"ninja spinners\", \"simple\", \"chaotic\", \"repetition\", \"progression\", \"high contrast\", \"improvisation\", \"playfield usage\", \"playfield constraint\", \"video gimmick\", \"difficulty spike\", \"low sv\", \"high sv\", \"colorhax\", \"tech\", \"slider tech\", \"complex sv\", \"reading\", \"visually dense\", \"overlap reading\", \"alt\", \"jump aim\", \"sharp aim\", \"wide aim\", \"linear aim\", \"aim control\", \"flow aim\", \"precision\", \"finger control\", \"complex snap divisors\", \"bursts\", \"streams\", \"spaced streams\", \"cutstreams\", \"stamina\", \"mapping contest\", \"tournament custom\", \"tag\", \"port\"] {allow-input: true}\n",
153
+ "#@markdown ---\n",
154
+ "#@markdown If true, the generated beatmap will be exported as a .osz file. Otherwise, it will be exported as a .osu file.\n",
155
+ "export_osz = False # @param {type:\"boolean\"}\n",
156
+ "#@markdown If true, the generated beatmap will be added to the reference beatmap and the reference beatmap will be modified instead of creating a new beatmap. It will also continue any hit objects before the start time in the reference beatmap.\n",
157
+ "add_to_beatmap = False # @param {type:\"boolean\"}\n",
158
+ "#@markdown This is the start time of the beatmap in milliseconds. Use this to constrain the generation to a specific part of the song.\n",
159
+ "start_time = -1 # @param {type:\"integer\"}\n",
160
+ "#@markdown This is the end time of the beatmap in milliseconds. Use this to constrain the generation to a specific part of the song.\n",
161
+ "end_time = -1 # @param {type:\"integer\"}\n",
162
+ "#@markdown This is which additional information to give to the model:\n",
163
+ "#@markdown - TIMING: Give timing points to the model. This will skip the timing point generation step.\n",
164
+ "#@markdown - KIAI: Give kiai times to the model. This will skip the kiai time generation step.\n",
165
+ "#@markdown - MAP: Give hit objects to the model. This will skip the hit object generation step.\n",
166
+ "#@markdown - GD: Give hit objects of another difficulty in the same mapset to the model (can be a different game mode). It will improve the rhythm accuracy and consistency of the generated beatmap without copying the reference beatmap.\n",
167
+ "#@markdown - NO_HS: Give hit objects without hitsounds to the model. This will copy the hit objects of the reference beatmap and only add hitsounds to them.\n",
168
+ "in_context = \"[NONE]\" # @param [\"[NONE]\", \"[TIMING]\", \"[TIMING,KIAI]\", \"[TIMING,KIAI,MAP]\", \"[GD,TIMING,KIAI]\", \"[NO_HS,TIMING,KIAI]\"]\n",
169
+ "#@markdown This is the output type of the beatmap. You can choose to either generate everything or only generate timing points.\n",
170
+ "output_type = \"[MAP]\" # @param [\"[MAP]\", \"[TIMING,KIAI,MAP,SV]\", \"[TIMING]\"]\n",
171
+ "#@markdown This is the scale of the classifier-free guidance. A higher scale will make the model more likely to follow the descriptors and mapper style. A high `cfg_scale` or certain combinations of settings can produce unexpected results, so use it with caution. \n",
172
+ "cfg_scale = 1 # @param {type:\"slider\", min:1, max:5, step:0.1}\n",
173
+ "#@markdown This is the temperature of the sampling. A lower temperature will make the model more conservative and less creative. I only recommend lowering this slightly or when using `add_to_beatmap` and generating small sections.\n",
174
+ "temperature = 1 # @param {type:\"slider\", min:0, max:1, step:0.01}\n",
175
+ "#@markdown This is the random seed. Change this to sample a different beatmap with the same settings.\n",
176
+ "seed = -1 # @param {type:\"integer\"}\n",
177
+ "#@markdown ---\n",
178
+ "#@markdown If true, uses a slow and accurate timing generator. This will make the generation slower, but the timing will be more accurate.\n",
179
+ "#@markdown This is the leniency of the normal timing generator. It will allow the timing ticks to deviate from the real timing by this many milliseconds. A higher value will result in less timing points.\n",
180
+ "timing_leniency = 20 # @param {type:\"slider\", min:0, max:100, step:1}\n",
181
+ "super_timing = False # @param {type:\"boolean\"}\n",
182
+ "#@markdown This is the number of beams for beam search for the super timing generator. Higher values will result in slightly more accurate timing at the cost of speed. \n",
183
+ "timer_num_beams = 2 # @param {type:\"slider\", min:1, max:16, step:1}\n",
184
+ "#@markdown This is the number of iterations for the super timing generator. Higher values will result in slightly more accurate timing at the cost of speed.\n",
185
+ "timer_iterations = 20 # @param {type:\"slider\", min:1, max:100, step:1}\n",
186
+ "#@markdown This is the certainty threshold requirement for BPM changes in the super timing generator. Higher values will result in less BPM changes.\n",
187
+ "timer_bpm_threshold = 0.1 # @param {type:\"slider\", min:0, max:1, step:0.1}\n",
188
+ "#@markdown ---\n",
189
+ "\n",
190
+ "# Get actual parameters\n",
191
+ "a_config = model.split(' ')[-1].lower()\n",
192
+ "a_gamemode = [\"standard\", \"taiko\", \"catch the beat\", \"mania\"].index(gamemode)\n",
193
+ "a_difficulty = None if difficulty == -1 else difficulty\n",
194
+ "a_mapper_id = None if mapper_id == -1 else mapper_id\n",
195
+ "a_year = None if year == -1 else year\n",
196
+ "a_hp_drain_rate = None if hp_drain_rate == -1 else hp_drain_rate\n",
197
+ "a_circle_size = None if circle_size == -1 else circle_size\n",
198
+ "a_overall_difficulty = None if overall_difficulty == -1 else overall_difficulty\n",
199
+ "a_approach_rate = None if approach_rate == -1 else approach_rate\n",
200
+ "a_slider_multiplier = None if slider_multiplier == -1 else slider_multiplier\n",
201
+ "a_slider_tick_rate = None if slider_tick_rate == -1 else slider_tick_rate\n",
202
+ "a_hold_note_ratio = None if hold_note_ratio == -1 else hold_note_ratio\n",
203
+ "a_scroll_speed_ratio = None if scroll_speed_ratio == -1 else scroll_speed_ratio\n",
204
+ "descriptors = [d for d in [descriptor_1, descriptor_2, descriptor_3] if d != '']\n",
205
+ "negative_descriptors = [d for d in [negative_descriptor_1, negative_descriptor_2, negative_descriptor_3] if d != '']\n",
206
+ "\n",
207
+ "a_start_time = None if start_time == -1 else start_time\n",
208
+ "a_end_time = None if end_time == -1 else end_time\n",
209
+ "a_in_context = [ContextType(c.lower()) for c in in_context[1:-1].split(',')]\n",
210
+ "a_output_type = [ContextType(c.lower()) for c in output_type[1:-1].split(',')]\n",
211
+ "a_seed = None if seed == -1 else seed\n",
212
+ "\n",
213
+ "# Validate stuff\n",
214
+ "if any(c in a_in_context for c in [ContextType.TIMING, ContextType.KIAI, ContextType.MAP, ContextType.SV, ContextType.GD, ContextType.NO_HS]) or add_to_beatmap:\n",
215
+ " assert os.path.exists(input_beatmap), \"Please upload a reference beatmap.\"\n",
216
+ "assert os.path.exists(input_audio), \"Please upload an audio file.\"\n",
217
+ "if a_config == \"v30\":\n",
218
+ " assert a_gamemode == 0, \"V30 only supports standard mode.\"\n",
219
+ " if any(c in a_in_context for c in [ContextType.KIAI, ContextType.MAP, ContextType.SV]):\n",
220
+ " print(\"WARNING: V30 does not support KIAI, MAP, or SV in_context, ignoring.\")\n",
221
+ " if output_type != \"[MAP]\":\n",
222
+ " print(\"WARNING: V30 only supports [MAP] output type, setting output type to [MAP].\")\n",
223
+ " a_output_type = [ContextType.MAP]\n",
224
+ " if len(descriptors) != 0 and len(negative_descriptors) != 0:\n",
225
+ " print(\"WARNING: V30 does not support descriptors or negative descriptors, ignoring.\")\n",
226
+ " if super_timing:\n",
227
+ " print(\"WARNING: V30 does not fully support super timing, generation will be VERY slow.\")\n",
228
+ " \n",
229
+ "# Create config\n",
230
+ "with initialize_config_dir(version_base=\"1.1\", config_dir=\"/content/Mapperatorinator/configs/inference\"):\n",
231
+ " conf = compose(config_name=a_config)\n",
232
+ "\n",
233
+ "# Do inference\n",
234
+ "conf.audio_path = input_audio\n",
235
+ "conf.output_path = output_path\n",
236
+ "conf.beatmap_path = input_beatmap\n",
237
+ "conf.gamemode = a_gamemode\n",
238
+ "conf.difficulty = a_difficulty\n",
239
+ "conf.mapper_id = a_mapper_id\n",
240
+ "conf.year = a_year\n",
241
+ "conf.hitsounded = hitsounded\n",
242
+ "conf.hp_drain_rate = a_hp_drain_rate\n",
243
+ "conf.circle_size = a_circle_size\n",
244
+ "conf.overall_difficulty = a_overall_difficulty\n",
245
+ "conf.approach_rate = a_approach_rate\n",
246
+ "conf.slider_multiplier = a_slider_multiplier\n",
247
+ "conf.slider_tick_rate = a_slider_tick_rate\n",
248
+ "conf.keycount = keycount\n",
249
+ "conf.hold_note_ratio = a_hold_note_ratio\n",
250
+ "conf.scroll_speed_ratio = a_scroll_speed_ratio\n",
251
+ "conf.descriptors = descriptors\n",
252
+ "conf.negative_descriptors = negative_descriptors\n",
253
+ "conf.export_osz = export_osz\n",
254
+ "conf.add_to_beatmap = add_to_beatmap\n",
255
+ "conf.start_time = a_start_time\n",
256
+ "conf.end_time = a_end_time\n",
257
+ "conf.in_context = a_in_context\n",
258
+ "conf.output_type = a_output_type\n",
259
+ "conf.cfg_scale = cfg_scale\n",
260
+ "conf.temperature = temperature\n",
261
+ "conf.seed = a_seed\n",
262
+ "conf.timing_leniency = timing_leniency\n",
263
+ "conf.super_timing = super_timing\n",
264
+ "conf.timer_num_beams = timer_num_beams\n",
265
+ "conf.timer_iterations = timer_iterations\n",
266
+ "conf.timer_bpm_threshold = timer_bpm_threshold\n",
267
+ "\n",
268
+ "_, result_path, osz_path = main(conf)\n",
269
+ "\n",
270
+ "if osz_path is not None:\n",
271
+ " result_path = osz_path\n",
272
+ "\n",
273
+ "if conf.add_to_beatmap:\n",
274
+ " files.download(result_path)\n",
275
+ "else:\n",
276
+ " files.download(result_path)\n"
277
+ ],
278
+ "id": "166eb3e5f9398554",
279
+ "outputs": [],
280
+ "execution_count": null
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "Python 3",
286
+ "language": "python",
287
+ "name": "python3"
288
+ },
289
+ "accelerator": "GPU",
290
+ "language_info": {
291
+ "codemirror_mode": {
292
+ "name": "ipython",
293
+ "version": 2
294
+ },
295
+ "file_extension": ".py",
296
+ "mimetype": "text/x-python",
297
+ "name": "python",
298
+ "nbconvert_exporter": "python",
299
+ "pygments_lexer": "ipython2",
300
+ "version": "2.7.6"
301
+ }
302
+ },
303
+ "nbformat": 4,
304
+ "nbformat_minor": 5
305
+ }
collate_results.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+
5
def get_color_for_value(value, min_val, max_val, lower_is_better=False):
    """Map *value* onto a red-to-green HSL color within ``[min_val, max_val]``.

    Args:
        value (float): The value to color.
        min_val (float): Smallest value of this metric in the dataset.
        max_val (float): Largest value of this metric in the dataset.
        lower_is_better (bool): When True, smaller values map to greener hues.

    Returns:
        str: A CSS-ready ``hsl(...)`` color string.
    """
    # Degenerate range (all values identical): nothing to grade, show green.
    if min_val == max_val:
        return "hsl(120, 70%, 60%)"

    # Position of the value inside the range, as a 0-1 fraction.
    fraction = (value - min_val) / (max_val - min_val)

    # Hue 0 is red and hue 120 is green; flip the scale when lower is better.
    hue = (1 - fraction) * 120 if lower_is_better else fraction * 120

    return f"hsl({hue:.0f}, 70%, 60%)"
35
+
36
+
37
def _read_latest_metrics(log_file_path, metric_patterns):
    """Scan a log file and return the last reported value for each metric.

    Later occurrences of a metric overwrite earlier ones, so a log that was
    appended to by multiple runs yields the most recent numbers.

    Args:
        log_file_path (str): Path to the ``calc_fid.log`` file.
        metric_patterns (dict[str, re.Pattern]): Column name -> regex with one
            capturing group for the numeric value.

    Returns:
        dict[str, float]: Metric name -> last parsed value.
    """
    latest_metrics = {}
    # Explicit encoding: log files are written as UTF-8; relying on the
    # platform default decoding is not portable.
    with open(log_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            for key, pattern in metric_patterns.items():
                match = pattern.search(line)
                if match:
                    latest_metrics[key] = float(match.group(1))
    return latest_metrics


def _min_max_by_metric(results, headers):
    """Compute the min/max of every numeric column, for cell coloring."""
    min_max_vals = {}
    for header in headers:
        if header == "Model name":
            continue
        values = [res.get(header) for res in results if res.get(header) is not None]
        if values:
            min_max_vals[header] = {'min': min(values), 'max': max(values)}
    return min_max_vals


def _render_html_table(results, headers, min_max_vals):
    """Render the collected metrics as an HTML table with colored cells."""
    html = ["<table>"]
    # Header row
    html.append(" <thead>")
    html.append(" <tr>" + "".join([f"<th>{h}</th>" for h in headers]) + "</tr>")
    html.append(" </thead>")

    # Data rows, sorted by model name for a stable, readable table.
    html.append(" <tbody>")
    for res in sorted(results, key=lambda x: x.get('Model name', '')):
        row_html = " <tr>"
        for header in headers:
            value = res.get(header)

            if header == 'Model name':
                row_html += f"<td>{res.get('Model name', 'N/A')}</td>"
                continue

            if value is None:
                row_html += "<td>N/A</td>"
                continue

            # FID is a distance (lower is better) and is shown coarser than
            # the precision/recall/F1 scores.
            if header == 'FID':
                formatted_value = f"{value:.2f}"
                lower_is_better = True
            else:
                formatted_value = f"{value:.3f}"
                lower_is_better = False

            # Get color and apply style
            color = get_color_for_value(value, min_max_vals[header]['min'], min_max_vals[header]['max'],
                                        lower_is_better)
            # Light text shadow for better readability on bright colors.
            style = f"background-color: {color}; color: black; text-shadow: 0 0 5px white;"
            row_html += f'<td style="{style}">{formatted_value}</td>'

        row_html += "</tr>"
        html.append(row_html)

    html.append(" </tbody>")
    html.append("</table>")

    return "\n".join(html)


def parse_log_files(root_dir):
    """
    Parse ``calc_fid.log`` files found in the ``inference=...`` subfolders of
    *root_dir* and format the extracted metrics into an HTML table with
    colored cells.

    Args:
        root_dir (str): The path to the main folder containing the model
            subfolders.

    Returns:
        str: The formatted HTML table, or an explanatory ``<p>`` element when
        no results were found.
    """
    results = []
    dir_pattern = re.compile(r"inference=(inference_)?([a-zA-Z0-9_-]+)")
    metric_patterns = {
        'FID': re.compile(r"FID: ([\d.]+)"),
        'AR Pr.': re.compile(r"Active Rhythm Precision: ([\d.]+)"),
        'AR Re.': re.compile(r"Active Rhythm Recall: ([\d.]+)"),
        'AR F1': re.compile(r"Active Rhythm F1: ([\d.]+)"),
        'PR Pr.': re.compile(r"Passive Rhythm Precision: ([\d.]+)"),
        'PR Re.': re.compile(r"Passive Rhythm Recall: ([\d.]+)"),
        'PR F1': re.compile(r"Passive Rhythm F1: ([\d.]+)")
    }

    for dirpath, dirnames, filenames in os.walk(root_dir):
        if dirpath == root_dir:
            for dirname in dirnames:
                dir_match = dir_pattern.match(dirname)
                if not dir_match:
                    continue
                model_name = dir_match.group(2)
                log_file_path = os.path.join(dirpath, dirname, 'calc_fid.log')

                if not os.path.exists(log_file_path):
                    print(f"Warning: 'calc_fid.log' not found in {dirname}")
                    continue

                try:
                    latest_metrics = _read_latest_metrics(log_file_path, metric_patterns)
                except Exception as e:
                    print(f"Error reading {log_file_path}: {e}")
                    continue

                if latest_metrics:
                    latest_metrics['Model name'] = model_name
                    results.append(latest_metrics)
            # Prune the walk in place: only the top-level directories matter.
            dirnames[:] = []

    if not results:
        return "<p>No results found. Check if <code>root_dir</code> is correct and log files exist.</p>"

    headers = ["Model name", "FID", "AR Pr.", "AR Re.", "AR F1", "PR Pr.", "PR Re.", "PR F1"]
    min_max_vals = _min_max_by_metric(results, headers)
    return _render_html_table(results, headers, min_max_vals)
148
+
149
+
150
if __name__ == '__main__':
    # NOTE: point this at your main results folder. "." works when this
    # script sits in the same parent folder as the "inference=..." folders.
    logs_directory = './logs_fid/sweeps/test_3'
    # parse_log_files returns the rendered HTML table (or a <p> fallback).
    print(parse_log_files(logs_directory))
158
+
compose.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mapperatorinator
2
+ services:
3
+ mapperatorinator:
4
+ stdin_open: true
5
+ tty: true
6
+ deploy:
7
+ resources:
8
+ reservations:
9
+ devices:
10
+ - driver: nvidia
11
+ count: all
12
+ capabilities:
13
+ - gpu
14
+ volumes:
15
+ - .:/workspace/Mapperatorinator
16
+ - ../datasets:/workspace/datasets
17
+ network_mode: host
18
+ container_name: mapperatorinator_space
19
+ shm_size: 8gb
20
+ build: .
21
+ # image: my_fixed_image
22
+ command: /bin/bash
23
+ environment:
24
+ - PROJECT_PATH=/workspace/Mapperatorinator
25
+ - WANDB_API_KEY=${WANDB_API_KEY}
config.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Optional
3
+
4
+ from hydra.core.config_store import ConfigStore
5
+ from omegaconf import MISSING
6
+
7
+ from osuT5.osuT5.config import TrainConfig
8
+ from osuT5.osuT5.tokenizer import ContextType
9
+ from osu_diffusion.config import DiffusionTrainConfig
10
+
11
+
12
+ # BeatHeritage V1 Config Sections
13
+
14
@dataclass
class AdvancedFeaturesConfig:
    """Boolean toggles for BeatHeritage's advanced generation features.

    NOTE(review): the effect of each flag is not visible in this file —
    confirm exact semantics against the consuming generation pipeline.
    """
    enable_context_aware_generation: bool = True
    enable_style_preservation: bool = True
    enable_difficulty_scaling: bool = True
    enable_pattern_variety: bool = True
20
+
21
@dataclass
class QualityControlConfig:
    """Thresholds and toggles for post-generation quality control."""
    min_distance_threshold: int = 20  # presumably minimum spacing between objects — TODO confirm units (osu! pixels?)
    max_overlap_ratio: float = 0.15  # maximum tolerated object overlap ratio — TODO confirm consumer
    enable_auto_correction: bool = True  # automatically correct detected issues — TODO confirm
    enable_flow_optimization: bool = True  # NOTE(review): semantics not visible here — verify against postprocessor
27
+
28
@dataclass
class PerformanceConfig:
    """Performance-related knobs for inference."""
    use_flash_attention: bool = False  # enable FlashAttention kernels — TODO confirm model support
    batch_size: int = 1  # inference batch size
    max_sequence_length: int = 5120  # maximum token sequence length
    cache_size: int = 4096  # presumably an internal cache size — TODO confirm units (entries vs bytes)
34
+
35
@dataclass
class MetadataConfig:
    """Toggles controlling how beatmap metadata is carried through generation.

    NOTE(review): flag semantics are inferred from their names only —
    confirm against the code that reads this section.
    """
    preserve_timing_points: bool = True
    preserve_bookmarks: bool = True
    auto_detect_kiai: bool = True
    smart_hitsounding: bool = True
41
+
42
@dataclass
class PostprocessorConfig:
    """Selects the postprocessor implementation by dotted import path."""
    use_custom: bool = True  # use the custom BeatHeritage postprocessor instead of the default
    class_name: str = 'beatheritage_postprocessor.BeatHeritagePostprocessor'  # dotted path to the postprocessor class — presumably resolved at runtime, TODO confirm loader
    config_class: str = 'beatheritage_postprocessor.BeatHeritageConfig'  # dotted path to its config dataclass
47
+
48
@dataclass
class IntegrationsConfig:
    """Toggles for optional tool integrations."""
    mai_mod_enhanced: bool = True  # enhanced Mai-Mod integration — TODO confirm effect
    fid_evaluation: bool = True  # FID evaluation integration
    benchmark_mode: bool = False  # benchmark mode, off by default — TODO confirm consumer
53
+
54
+
55
+ # Default config here based on V28
56
+
57
@dataclass
class InferenceConfig:
    """Hydra structured config for beatmap-generation inference.

    Groups model/audio paths, conditional-generation style controls, sampling
    hyperparameters, output metadata, and settings for the position-diffusion
    stage. Registered below in the "inference" config group as "base".
    """
    model_path: str = ''  # Path to trained model
    audio_path: str = ''  # Path to input audio
    output_path: str = ''  # Path to output directory
    beatmap_path: str = ''  # Path to .osu file to autofill metadata and use as reference

    # Conditional generation settings
    gamemode: Optional[int] = None  # Gamemode of the beatmap
    beatmap_id: Optional[int] = None  # Beatmap ID to use as style
    difficulty: Optional[float] = None  # Difficulty star rating to map
    mapper_id: Optional[int] = None  # Mapper ID to use as style
    year: Optional[int] = None  # Year to use as style
    hitsounded: Optional[bool] = None  # Whether the beatmap has hitsounds
    keycount: Optional[int] = None  # Number of keys to use for mania
    hold_note_ratio: Optional[float] = None  # Ratio of how many hold notes to generate in mania
    scroll_speed_ratio: Optional[float] = None  # Ratio of how many scroll speed changes to generate in mania and taiko
    descriptors: Optional[list[str]] = None  # List of descriptors to use for style
    negative_descriptors: Optional[list[str]] = None  # List of descriptors to avoid when using classifier-free guidance

    # Difficulty settings
    hp_drain_rate: Optional[float] = None  # HP drain rate (HP)
    circle_size: Optional[float] = None  # Circle size (CS)
    overall_difficulty: Optional[float] = None  # Overall difficulty (OD)
    approach_rate: Optional[float] = None  # Approach rate (AR)
    slider_multiplier: Optional[float] = None  # Multiplier for slider velocity
    slider_tick_rate: Optional[float] = None  # Rate of slider ticks

    # Inference settings
    seed: Optional[int] = None  # Random seed
    device: str = 'auto'  # Inference device (cpu/cuda/mps/auto)
    precision: str = 'fp32'  # Lower precision for speed (fp32/bf16/amp)
    add_to_beatmap: bool = False  # Add generated content to the reference beatmap
    export_osz: bool = False  # Export beatmap as .osz file
    start_time: Optional[int] = None  # Start time of audio to generate beatmap for
    end_time: Optional[int] = None  # End time of audio to generate beatmap for
    lookback: float = 0.5  # Fraction of audio sequence to fill with tokens from previous inference window
    lookahead: float = 0.4  # Fraction of audio sequence to skip at the end of the audio window
    timing_leniency: int = 20  # Number of milliseconds of error to allow for timing generation
    in_context: list[ContextType] = field(default_factory=lambda: [ContextType.NONE])  # Context types of other beatmap(s)
    output_type: list[ContextType] = field(default_factory=lambda: [ContextType.MAP])  # Output type (map, timing)
    cfg_scale: float = 1.0  # Scale of classifier-free guidance
    temperature: float = 1.0  # Sampling temperature
    timing_temperature: float = 0.1  # Sampling temperature for timing
    mania_column_temperature: float = 0.5  # Sampling temperature for mania columns
    taiko_hit_temperature: float = 0.5  # Sampling temperature for taiko hit types
    timeshift_bias: float = 0.0  # Logit bias for sampling timeshift tokens
    top_p: float = 0.95  # Top-p sampling threshold
    top_k: int = 0  # Top-k sampling threshold
    repetition_penalty: float = 1.0  # Repetition penalty to reduce repetitive patterns
    parallel: bool = False  # Use parallel sampling
    do_sample: bool = True  # Use sampling
    num_beams: int = 1  # Number of beams for beam search
    super_timing: bool = False  # Use super timing generator (slow but accurate timing)
    timer_num_beams: int = 2  # Number of beams for beam search in the timing generator
    timer_bpm_threshold: float = 0.7  # Threshold requirement for BPM change in timer, higher values will result in less BPM changes
    timer_cfg_scale: float = 1.0  # Scale of classifier-free guidance for timer
    timer_iterations: int = 20  # Number of iterations for timer
    use_server: bool = True  # Use server for optimized multiprocess inference
    max_batch_size: int = 16  # Maximum batch size for inference (only used for parallel sampling or super timing)
    resnap_events: bool = True  # Resnap notes to the timing after generation
    position_refinement: bool = False  # Use position refinement

    # Metadata settings
    bpm: int = 120  # Beats per minute of input audio
    offset: int = 0  # Start of beat, in milliseconds, from the beginning of input audio
    title: str = ''  # Song title
    artist: str = ''  # Song artist
    creator: str = ''  # Beatmap creator
    version: str = ''  # Beatmap version
    background: Optional[str] = None  # File name of background image
    preview_time: int = -1  # Time in milliseconds to start previewing the song

    # Diffusion settings
    generate_positions: bool = True  # Use diffusion to generate object positions
    diff_cfg_scale: float = 1.0  # Scale of classifier-free guidance
    compile: bool = False  # PyTorch 2.0 optimization
    pad_sequence: bool = False  # Pad sequence to max_seq_len
    diff_ckpt: str = ''  # Path to checkpoint for diffusion model
    diff_refine_ckpt: str = ''  # Path to checkpoint for refining diffusion model
    beatmap_idx: str = 'osu_diffusion/beatmap_idx.pickle'  # Path to beatmap index
    refine_iters: int = 10  # Number of refinement iterations
    random_init: bool = False  # Whether to initialize with random noise instead of positions generated by the previous model
    timesteps: list[int] = field(default_factory=lambda: [100, 0, 0, 0, 0, 0, 0, 0, 0, 0])  # The number of timesteps we want to take from equally-sized portions of the original process
    max_seq_len: int = 1024  # Maximum sequence length for diffusion
    overlap_buffer: int = 128  # Buffer zone at start and end of sequence to avoid edge effects (should be less than half of max_seq_len)

    # Training settings
    train: TrainConfig = field(default_factory=TrainConfig)  # Training settings for osuT5 model
    diffusion: DiffusionTrainConfig = field(default_factory=DiffusionTrainConfig)  # Training settings for diffusion model

    # BeatHeritage V1 config sections (dataclasses defined above)
    advanced_features: AdvancedFeaturesConfig = field(default_factory=AdvancedFeaturesConfig)
    quality_control: QualityControlConfig = field(default_factory=QualityControlConfig)
    performance: PerformanceConfig = field(default_factory=PerformanceConfig)
    metadata: MetadataConfig = field(default_factory=MetadataConfig)
    postprocessor: PostprocessorConfig = field(default_factory=PostprocessorConfig)
    integrations: IntegrationsConfig = field(default_factory=IntegrationsConfig)
    hydra: Any = MISSING  # Hydra runtime section, filled in by Hydra itself
156
+
157
+
158
@dataclass
class FidConfig:
    """Hydra config for evaluating generated beatmaps (FID and rhythm stats).

    Registered below as "base_fid"; consumed via configs/calc_fid.yaml.
    """
    device: str = 'auto'  # Inference device (cpu/cuda/mps/auto)
    compile: bool = True  # PyTorch 2.0 optimization
    num_processes: int = 3  # presumably number of parallel worker processes — TODO confirm
    seed: int = 0  # Random seed

    skip_generation: bool = False  # Skip the generation step (evaluate existing output)
    fid: bool = True  # Compute the FID metric
    rhythm_stats: bool = True  # Compute rhythm statistics

    dataset_type: str = 'ors'  # Reference dataset format ('ors'; calc_fid.yaml also uses 'mmrs')
    dataset_path: str = '/workspace/datasets/ORS16291'  # Path to the reference dataset
    dataset_start: int = 16200  # First dataset index to include
    dataset_end: int = 16291  # Last dataset index — TODO confirm whether inclusive
    gamemodes: list[int] = field(default_factory=lambda: [0])  # List of gamemodes to include in the dataset

    classifier_ckpt: str = 'OliBomby/osu-classifier'  # Classifier checkpoint (Hugging Face repo id or local path)
    classifier_batch_size: int = 16  # Batch size for classifier inference

    training_set_ids_path: Optional[str] = None  # Path to training set beatmap IDs

    inference: InferenceConfig = field(default_factory=InferenceConfig)  # Inference settings used to generate the evaluated beatmaps
    hydra: Any = MISSING  # Hydra runtime section, filled in by Hydra itself
182
+
183
+
184
@dataclass
class MaiModConfig:
    """Hydra config for the Mai-Mod beatmap-checking entry point.

    Registered below as "base_mai_mod".
    """
    beatmap_path: str = ''  # Path to .osu file
    audio_path: str = ''  # Path to input audio
    raw_output: bool = False  # Emit raw output — TODO confirm exact format against the consumer
    precision: str = 'fp32'  # Lower precision for speed (fp32/bf16/amp)
    inference: InferenceConfig = field(default_factory=InferenceConfig)  # Inference settings for the osuT5 model
    hydra: Any = MISSING  # Hydra runtime section, filled in by Hydra itself
192
+
193
+
194
# Register the structured configs with Hydra so YAML files can extend them
# via their "defaults" lists.
cs = ConfigStore.instance()
for _registration in (
    dict(group="inference", name="base", node=InferenceConfig),
    dict(name="base_fid", node=FidConfig),
    dict(name="base_mai_mod", node=MaiModConfig),
):
    cs.store(**_registration)
configs/calc_fid.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_fid
3
+ - inference: tiny_dist7
4
+ - _self_
5
+
6
+ compile: false
7
+ num_processes: 32
8
+ seed: 0
9
+
10
+ skip_generation: false
11
+ fid: true
12
+ rhythm_stats: true
13
+
14
+ classifier_ckpt: 'OliBomby/osu-classifier'
15
+ classifier_batch_size: 32
16
+
17
+ training_set_ids_path: null
18
+
19
+ dataset_type: "mmrs"
20
+ dataset_path: C:/Users/Olivier/Documents/Collections/Beatmap ML Datasets/MMRS2025
21
+ dataset_start: 0
22
+ dataset_end: 106 # Contains 324 std beatmaps
23
+ gamemodes: [0] # List of gamemodes to include in the dataset
24
+
25
+ inference:
26
+ super_timing: false
27
+ temperature: 0.9 # Sampling temperature
28
+ top_p: 0.9 # Top-p sampling threshold
29
+ lookback: 0.5 # Fraction of audio sequence to fill with tokens from previous inference window
30
+ lookahead: 0.4 # Fraction of audio sequence to skip at the end of the audio window
31
+ year: 2023
32
+ resnap_events: false
33
+ use_server: false
34
+
35
+ hydra:
36
+ job:
37
+ chdir: True
38
+ run:
39
+ # dir: ./logs_fid/${now:%Y-%m-%d}/${now:%H-%M-%S}
40
+ dir: ./logs_fid/test
41
+ sweep:
42
+ dir: ./logs_fid/sweeps/test_3
43
+ subdir: ${hydra.job.override_dirname}