# CM3P-ranked-classifier / parsing_cm3p.py
# Uploaded by OliBomby ("Upload 3 files", commit 8781d03, verified)
import dataclasses
from datetime import timedelta
from enum import Enum
from os import PathLike
from typing import Optional, Union, IO
import numpy as np
import numpy.typing as npt
from slider import Beatmap, Circle, Slider, Spinner, HoldNote, TimingPoint
from slider.curve import Linear, Catmull, Perfect, MultiBezier
from transformers import FeatureExtractionMixin, AutoFeatureExtractor
from .configuration_cm3p import CM3PConfig
class EventType(Enum):
    """Types of events emitted by the beatmap parser.

    Covers hit objects for all four osu! game modes (standard, taiko, catch,
    mania), slider anchor points, and global events (timing, kiai, scroll speed).
    """
    # Standard / catch hit objects
    CIRCLE = "circle"
    SPINNER = "spinner"
    SPINNER_END = "spinner_end"
    # Slider events; anchor types mirror the slider curve types
    SLIDER_HEAD = "slider_head"
    BEZIER_ANCHOR = "bezier_anchor"
    PERFECT_ANCHOR = "perfect_anchor"
    CATMULL_ANCHOR = "catmull_anchor"
    RED_ANCHOR = "red_anchor"
    LAST_ANCHOR = "last_anchor"
    SLIDER_END = "slider_end"
    REPEAT_END = "repeat_end"
    # Timing events
    BEAT = "beat"
    MEASURE = "measure"
    TIMING_POINT = "timing_point"
    # Kiai-mode toggles
    KIAI_ON = "kiai_on"
    KIAI_OFF = "kiai_off"
    # Mania hold notes
    HOLD_NOTE = "hold_note"
    HOLD_NOTE_END = "hold_note_end"
    SCROLL_SPEED_CHANGE = "scroll_speed_change"
    # Taiko-specific objects (sliders become drumrolls, spinners become dendens)
    DRUMROLL = "drumroll"
    DRUMROLL_END = "drumroll_end"
    DENDEN = "denden"
    DENDEN_END = "denden_end"
# Event types whose groups may carry the `new_combo` flag (see Group.new_combo).
EVENT_TYPES_WITH_NEW_COMBO = [
    EventType.CIRCLE,
    EventType.SLIDER_HEAD,
]
@dataclasses.dataclass
class Group:
    """A single parsed beatmap event and its optional attributes."""
    event_type: Optional[EventType] = None  # kind of event this group represents
    time: int = 0  # event time in milliseconds
    has_time: bool = False  # whether `time` is meaningful (slider anchors carry a placeholder time)
    snapping: Optional[int] = None  # beat-snap divisor (1-16), when snapping is parsed
    distance: Optional[int] = None  # distance in osu! pixels from the previous position
    x: Optional[int] = None  # x position in osu! pixels
    y: Optional[int] = None  # y position in osu! pixels
    mania_column: Optional[int] = None  # mania column index derived from the x position
    new_combo: bool = False  # whether this object starts a new combo (standard/catch only)
    hitsounds: list[int] = dataclasses.field(default_factory=list)  # hitsound bitmasks (whistle/finish/clap)
    samplesets: list[int] = dataclasses.field(default_factory=list)  # sample sets (1=normal, 2=soft, 3=drum)
    additions: list[int] = dataclasses.field(default_factory=list)  # addition sample sets
    volumes: list[int] = dataclasses.field(default_factory=list)  # hitsound volumes (0-100)
    scroll_speed: Optional[float] = None  # scroll-speed multiplier (taiko/mania)
def merge_groups(groups1: list[Group], groups2: list[Group]) -> list[Group]:
    """Merge two time-sorted lists of groups into one time-sorted list.

    Args:
        groups1: First list of groups, sorted by time.
        groups2: Second list of groups, sorted by time.

    Returns:
        merged_groups: Merged list of groups, sorted by time.
    """
    merged_groups = []
    i = 0
    j = 0
    # Track the last seen time on each side so a group with a missing time
    # inherits the time of its predecessor for comparison purposes.
    t1 = -np.inf
    t2 = -np.inf
    while i < len(groups1) and j < len(groups2):
        # NOTE: explicit None checks instead of `time or t1` — a valid time of
        # 0 ms is falsy and must not fall back to the previous time.
        if groups1[i].time is not None:
            t1 = groups1[i].time
        if groups2[j].time is not None:
            t2 = groups2[j].time
        if t1 <= t2:
            merged_groups.append(groups1[i])
            i += 1
        else:
            merged_groups.append(groups2[j])
            j += 1
    # One list is exhausted; append the remainder of both.
    merged_groups.extend(groups1[i:])
    merged_groups.extend(groups2[j:])
    return merged_groups
def speed_groups(groups: list[Group], speed: float) -> list[Group]:
    """Rescale the times of a list of groups by a speed multiplier.

    Note: times are updated in place on the given Group objects; the returned
    list is a new list containing those same objects.

    Args:
        groups: List of groups.
        speed: Speed multiplier (>1 shrinks times, i.e. the map plays faster).

    Returns:
        sped_groups: List of the groups with rescaled times.
    """
    for group in groups:
        # Truncate towards zero after dividing by the speed multiplier.
        group.time = int(group.time / speed)
    return list(groups)
def get_median_mpb_beatmap(beatmap: Beatmap) -> float:
    """Return the duration-weighted median ms-per-beat of a beatmap.

    The reference end time is the start time of the last hit object; end times
    are only considered for hold notes (i.e. the last slider's/spinner's end
    time is not included).
    """
    last_time = max(ho.end_time if isinstance(ho, HoldNote) else ho.time for ho in beatmap.hit_objects(stacking=False))
    # Use total_seconds() rather than .seconds: .seconds truncates to whole
    # seconds (and wraps past one day), while the rest of this module converts
    # offsets with total_seconds() * 1000 to keep millisecond precision.
    last_time = int(last_time.total_seconds() * 1000)
    return get_median_mpb(beatmap.timing_points, last_time)
def get_median_mpb(timing_points: list[TimingPoint], last_time: float) -> float:
    """Return the ms-per-beat value that is active for the longest total time.

    Walks the timing points backwards, attributing to each uninherited beat
    length the span of time during which it is active, then picks the beat
    length with the largest accumulated duration. This mirrors the osu! stable
    implementation.

    Args:
        timing_points: All timing points of the beatmap (uninherited and inherited).
        last_time: End of the considered time range, in milliseconds.

    Returns:
        median: The dominant ms-per-beat value (0 if none qualifies).
    """
    this_beat_length = 0
    bpm_durations = {}
    for i in range(len(timing_points) - 1, -1, -1):
        tp = timing_points[i]
        # total_seconds() keeps millisecond precision (.seconds would truncate
        # to whole seconds, unlike the offset handling elsewhere in this file).
        offset = int(tp.offset.total_seconds() * 1000)
        if tp.parent is None:
            this_beat_length = tp.ms_per_beat
        # Skip inherited points (except a leading one) and points past the range.
        if this_beat_length == 0 or offset > last_time or (tp.parent is not None and i > 0):
            continue
        # The first timing point covers everything from time 0.
        if this_beat_length in bpm_durations:
            bpm_durations[this_beat_length] += int(last_time - (0 if i == 0 else offset))
        else:
            bpm_durations[this_beat_length] = int(last_time - (0 if i == 0 else offset))
        last_time = offset
    longest_time = 0
    median = 0
    for bpm, duration in bpm_durations.items():
        if duration > longest_time:
            longest_time = duration
            median = bpm
    return median
def load_beatmap(beatmap: Union[str, PathLike, IO[str], Beatmap]) -> Beatmap:
    """Load a beatmap from a file path, file object, or Beatmap object.

    Args:
        beatmap: Beatmap file path, file object, or Beatmap object.

    Returns:
        beatmap: Loaded Beatmap object (an already-loaded Beatmap is returned
        unchanged).
    """
    if isinstance(beatmap, (str, PathLike)):
        beatmap = Beatmap.from_path(beatmap)
    elif hasattr(beatmap, "read"):
        # NOTE: fixed from `isinstance(beatmap, IO)`, which never matches real
        # file objects (they do not subclass typing.IO), so file objects were
        # returned unparsed. Duck-type on the file protocol and hand the open
        # stream to slider's from_file instead of re-opening via `.name`.
        beatmap = Beatmap.from_file(beatmap)
    return beatmap
def get_song_length(
    samples: Optional[np.ndarray] = None,
    sample_rate: Optional[int] = None,
    beatmap: "Optional[Union[Beatmap, list[TimingPoint]]]" = None,
) -> float:
    """Best-effort song length in seconds.

    Prefers the audio itself; otherwise falls back to the beatmap's last hit
    object, then its last timing point, then 0.

    Args:
        samples: Audio samples, if available.
        sample_rate: Sample rate of ``samples``.
        beatmap: Beatmap or list of timing points to fall back on.

    Returns:
        Song length in seconds (0 when nothing is known).
    """
    if samples is not None and sample_rate is not None:
        return len(samples) / sample_rate
    if beatmap is None:
        return 0
    if isinstance(beatmap, Beatmap):
        hit_objects = beatmap.hit_objects(stacking=False)  # compute once, not twice
        if len(hit_objects) > 0:
            last_ho = hit_objects[-1]
            # Sliders/spinners/hold notes have an end_time; circles only a time.
            last_time = last_ho.end_time if hasattr(last_ho, "end_time") else last_ho.time
            return last_time.total_seconds() + 0.000999  # small buffer past the last time
    timing = beatmap.timing_points if isinstance(beatmap, Beatmap) else beatmap
    if len(timing) == 0:
        return 0
    return timing[-1].offset.total_seconds() + 0.01
class CM3PBeatmapParser(FeatureExtractionMixin):
    """
    A class to parse CM3P beatmap files.

    Converts .osu beatmaps into a flat, time-sorted list of :class:`Group`
    events (hit objects, slider anchors, timing beats, kiai toggles, scroll
    speed changes). The ``add_*`` flags control which optional features are
    emitted.
    """

    def __init__(
        self,
        add_timing: bool = True,
        add_snapping: bool = True,
        add_timing_points: bool = True,
        add_hitsounds: bool = True,
        add_distances: bool = True,
        add_positions: bool = True,
        add_kiai: bool = True,
        add_sv: bool = True,
        add_mania_sv: bool = True,
        mania_bpm_normalized_scroll_speed: bool = True,
        slider_version: int = 2,
        **kwargs,
    ):
        self.add_timing = add_timing
        self.add_snapping = add_snapping
        self.add_timing_points = add_timing_points
        self.add_hitsounds = add_hitsounds
        self.add_distances = add_distances
        self.add_positions = add_positions
        self.add_kiai = add_kiai
        self.add_sv = add_sv
        self.add_mania_sv = add_mania_sv
        self.mania_bpm_normalized_scroll_speed = mania_bpm_normalized_scroll_speed
        self.slider_version = slider_version
        super().__init__(**kwargs)

    def parse_beatmap(
        self,
        beatmap: Union[str, PathLike, IO[str], Beatmap],
        speed: float = 1.0,
        song_length: Optional[float] = None
    ) -> list[Group]:
        """Parse an .osu beatmap into a time-sorted list of Group events.

        Each hit object is parsed into one or more Group objects, in ascending
        order of time.

        Args:
            beatmap: Beatmap file path, file object, or Beatmap object.
            speed: Speed multiplier for the beatmap.
            song_length: Length of the song in seconds. If not provided, it
                will be calculated from the beatmap.

        Returns:
            result: List of Group events, sorted by time.
        """
        beatmap = load_beatmap(beatmap)
        hit_objects = beatmap.hit_objects(stacking=False)
        last_pos = np.array((256, 192))  # playfield centre
        groups = []

        for hit_object in hit_objects:
            if isinstance(hit_object, Circle):
                last_pos = self._parse_circle(hit_object, groups, last_pos, beatmap)
            elif isinstance(hit_object, Slider):
                if beatmap.mode == 1:  # taiko sliders are drumrolls
                    self._parse_drumroll(hit_object, groups, beatmap)
                else:
                    last_pos = self._parse_slider(hit_object, groups, last_pos, beatmap)
            elif isinstance(hit_object, Spinner):
                if beatmap.mode == 1:  # taiko spinners are dendens
                    self._parse_denden(hit_object, groups, beatmap)
                else:
                    last_pos = self._parse_spinner(hit_object, groups, beatmap)
            elif isinstance(hit_object, HoldNote):
                last_pos = self._parse_hold_note(hit_object, groups, beatmap)

        # Sort groups by time
        if len(groups) > 0:
            groups = sorted(groups, key=lambda x: x.time)

        # Merge in the optional global event streams (each already sorted).
        result = list(groups)
        if self.add_mania_sv and beatmap.mode == 3:
            scroll_speed_events = self.parse_scroll_speeds(beatmap)
            result = merge_groups(scroll_speed_events, result)
        if self.add_kiai:
            kiai_events = self.parse_kiai(beatmap)
            result = merge_groups(kiai_events, result)
        if self.add_timing:
            timing_events = self.parse_timing(beatmap, song_length=song_length)
            result = merge_groups(timing_events, result)
        if speed != 1.0:
            result = speed_groups(result, speed)
        return result

    def parse_scroll_speeds(self, beatmap: Beatmap, speed: float = 1.0) -> list[Group]:
        """Extract all (optionally BPM-normalized) scroll speed changes from a beatmap."""
        normalized = self.mania_bpm_normalized_scroll_speed
        groups = []
        median_mpb = get_median_mpb_beatmap(beatmap)
        mpb = median_mpb
        last_normalized_scroll_speed = -1
        for i, tp in enumerate(beatmap.timing_points):
            if tp.parent is None:
                # Uninherited point: changes the BPM and resets the SV to 1.
                mpb = tp.ms_per_beat
                scroll_speed = 1
            else:
                # Inherited point: ms_per_beat encodes the SV as -100 / SV.
                scroll_speed = -100 / tp.ms_per_beat
            # Only emit for the last timing point at a given offset.
            if i == len(beatmap.timing_points) - 1 or beatmap.timing_points[i + 1].offset > tp.offset:
                normalized_scroll_speed = scroll_speed * median_mpb / mpb if normalized else scroll_speed
                # Skip consecutive duplicates (but always emit the first change).
                if normalized_scroll_speed != last_normalized_scroll_speed or last_normalized_scroll_speed == -1:
                    self._add_group(
                        EventType.SCROLL_SPEED_CHANGE,
                        groups,
                        time=tp.offset,
                        beatmap=beatmap,
                        scroll_speed=normalized_scroll_speed,
                    )
                last_normalized_scroll_speed = normalized_scroll_speed
        if speed != 1.0:
            groups = speed_groups(groups, speed)
        return groups

    def parse_kiai(self, beatmap: Beatmap, speed: float = 1.0) -> list[Group]:
        """Extract all kiai on/off toggle events from a beatmap."""
        groups = []
        kiai = False
        for tp in beatmap.timing_points:
            # Only emit an event when the kiai state actually flips.
            if tp.kiai_mode == kiai:
                continue
            self._add_group(
                EventType.KIAI_ON if tp.kiai_mode else EventType.KIAI_OFF,
                groups,
                time=tp.offset,
                beatmap=beatmap,
            )
            kiai = tp.kiai_mode
        if speed != 1.0:
            groups = speed_groups(groups, speed)
        return groups

    def parse_timing(self, beatmap: Beatmap | list[TimingPoint], speed: float = 1.0, song_length: Optional[float] = None) -> list[Group]:
        """Extract timing point, measure, and beat events from a beatmap."""
        timing = beatmap.timing_points if isinstance(beatmap, Beatmap) else beatmap
        assert len(timing) > 0, "No timing points found in beatmap."
        groups = []
        last_time = song_length or get_song_length(beatmap=beatmap)
        last_time = int(last_time * 1000)
        # Get all timing points with BPM changes (uninherited points)
        timing_points = [tp for tp in timing if tp.bpm]
        for i, tp in enumerate(timing_points):
            # Generate beat and measure events until the next timing point.
            next_tp = timing_points[i + 1] if i + 1 < len(timing_points) else None
            # Stop 10 ms before the next timing point so its first beat is not doubled.
            next_time = next_tp.offset.total_seconds() * 1000 - 10 if next_tp else last_time
            start_time = tp.offset.total_seconds() * 1000
            time = start_time
            measure_counter = 0
            beat_delta = tp.ms_per_beat
            while time <= next_time:
                if self.add_timing_points and measure_counter == 0:
                    event_type = EventType.TIMING_POINT
                elif measure_counter % tp.meter == 0:
                    event_type = EventType.MEASURE
                else:
                    event_type = EventType.BEAT
                self._add_group(
                    event_type,
                    groups,
                    time=timedelta(milliseconds=time),
                    add_snap=False,
                )
                # Exit early if the beat_delta is too small to avoid infinite loops
                if beat_delta <= 10:
                    break
                # Recompute from start_time each iteration to avoid drift.
                measure_counter += 1
                time = start_time + measure_counter * beat_delta
        if speed != 1.0:
            groups = speed_groups(groups, speed)
        return groups

    @staticmethod
    def uninherited_point_at(time: timedelta, beatmap: Beatmap):
        """Return the uninherited (parent) timing point active at `time`."""
        tp = beatmap.timing_point_at(time)
        return tp if tp.parent is None else tp.parent

    @staticmethod
    def hitsound_point_at(time: timedelta, beatmap: Beatmap):
        """Return the timing point governing hitsounds at `time`.

        Queries 5 ms late to tolerate timing points placed slightly after the
        hit object they are meant to affect.
        """
        hs_query = time + timedelta(milliseconds=5)
        return beatmap.timing_point_at(hs_query)

    def scroll_speed_at(self, time: timedelta, beatmap: Beatmap) -> float:
        """Return the slider velocity multiplier active at `time`."""
        query = time
        tp = beatmap.timing_point_at(query)
        return self.tp_to_scroll_speed(tp)

    def tp_to_scroll_speed(self, tp: TimingPoint) -> float:
        """Convert a timing point to a scroll speed multiplier, clamped to [0.01, 10]."""
        # Uninherited points and malformed values mean the default 1x speed.
        if tp.parent is None or tp.ms_per_beat >= 0 or np.isnan(tp.ms_per_beat):
            return 1
        else:
            return np.clip(-100 / tp.ms_per_beat, 0.01, 10)

    def _get_snapping(self, time: timedelta, beatmap: Beatmap, add_snap: bool = True) -> Optional[int]:
        """Determine the beat-snap divisor of `time`.

        Args:
            time: Time to compute the snapping for.
            beatmap: Beatmap object.
            add_snap: Whether snapping should be computed at all.

        Returns:
            The smallest divisor (1-16) that snaps `time` within 2 ms, 0 when
            unsnapped, or None when snapping is disabled.
        """
        if not add_snap or not self.add_snapping:
            return None
        tp = self.uninherited_point_at(time, beatmap)
        beats = (time - tp.offset).total_seconds() * 1000 / tp.ms_per_beat
        snapping = 0
        for i in range(1, 17):
            # If the difference between the time and the snapped time is less than 2 ms, that is the correct snapping
            if abs(beats - round(beats * i) / i) * tp.ms_per_beat < 2:
                snapping = i
                break
        return snapping

    def _get_hitsounds(self, time: timedelta, hitsound: int, addition: str, beatmap: Beatmap) -> tuple[int, int, int, int]:
        """Resolve the effective hitsound, sample set, addition set, and volume.

        NOTE(review): assumes `addition` follows the .osu extras format
        "sampleSet:additionSet:customIndex:volume[:filename]" — verify against
        the slider library's addition strings.
        """
        tp = self.hitsound_point_at(time, beatmap)
        tp_sample_set = tp.sample_type if tp.sample_type != 0 else 2  # Inherit to soft sample set
        addition_split = addition.split(":")
        # "0" means inherit: sample set from the timing point, addition set from the sample set.
        sample_set = int(addition_split[0]) if addition_split[0] != "0" else tp_sample_set
        addition_set = int(addition_split[1]) if addition_split[1] != "0" else sample_set
        volume = int(addition_split[3]) if len(addition_split) > 3 and addition_split[3] != "0" else tp.volume
        sample_set = sample_set if 0 < sample_set < 4 else 1  # Overflow default to normal sample set
        addition_set = addition_set if 0 < addition_set < 4 else 1  # Overflow default to normal sample set
        hitsound = hitsound & 14  # Only take the bits for whistle, finish, and clap
        volume = np.clip(volume, 0, 100)
        return hitsound, sample_set, addition_set, volume

    def _get_position(self, pos: npt.NDArray, last_pos: npt.NDArray) -> tuple[Optional[int], Optional[int], Optional[int], npt.NDArray]:
        """Compute the (x, y) position and distance features for `pos`.

        Returns (x, y, distance, pos); x/y/distance are None when the
        corresponding feature flag is disabled.
        """
        x, y, dist = None, None, None
        if self.add_distances:
            dist = int(np.linalg.norm(pos - last_pos))
        if self.add_positions:
            x = int(pos[0])
            y = int(pos[1])
        return x, y, dist, pos

    def _get_mania_column(self, pos: npt.NDArray, columns: int) -> int:
        """Map an x position (0-512 osu! pixels) to a mania column index."""
        column = int(np.clip(pos[0] / 512 * columns, 0, columns - 1))
        return column

    def _add_group(
        self,
        event_type: EventType,
        groups: list[Group],
        time: timedelta,
        *,
        beatmap: Beatmap = None,
        add_snap: bool = True,
        has_time: bool = True,
        pos: npt.NDArray = None,
        last_pos: npt.NDArray = None,
        new_combo: bool = False,
        hitsound_ref_times: list[timedelta] = None,
        hitsounds: list[int] = None,
        additions: list[str] = None,
        scroll_speed: Optional[float] = None,
    ) -> npt.NDArray:
        """Build a Group from the given features and append it to `groups`.

        Returns the (possibly updated) last position for distance tracking.
        """
        group = Group(
            event_type=event_type,
            # +1e-5 guards against float truncation just below the millisecond.
            time=int(time.total_seconds() * 1000 + 1e-5)
        )
        if has_time:
            group.has_time = True
            group.snapping = self._get_snapping(time, beatmap, add_snap)
        if pos is not None:
            if beatmap.mode in [0, 2]:  # standard / catch use x, y, distance
                x, y, dist, last_pos = self._get_position(pos, last_pos)
                group.x = x
                group.y = y
                group.distance = dist
            elif beatmap.mode == 3:  # mania uses the column index
                # NOTE: fixed from `group.column`, which silently created a new
                # attribute and left the `mania_column` dataclass field unset.
                group.mania_column = self._get_mania_column(pos, int(beatmap.circle_size))
        if new_combo and beatmap.mode in [0, 2]:
            group.new_combo = True
        if scroll_speed is not None:
            group.scroll_speed = scroll_speed
        if hitsound_ref_times is not None and self.add_hitsounds:
            for i, ref_time in enumerate(hitsound_ref_times):
                hitsound, sample_set, addition_set, volume = self._get_hitsounds(ref_time, hitsounds[i], additions[i], beatmap)
                group.hitsounds.append(hitsound)
                group.samplesets.append(sample_set)
                group.additions.append(addition_set)
                group.volumes.append(volume)
        groups.append(group)
        return last_pos

    def _parse_circle(self, circle: Circle, groups: list[Group], last_pos: npt.NDArray, beatmap: Beatmap) -> npt.NDArray:
        """Parse a circle hit object.

        Args:
            circle: Circle object.
            groups: List of groups to add to.
            last_pos: Last position of the hit objects.
            beatmap: Beatmap object.

        Returns:
            pos: Position of the circle.
        """
        return self._add_group(
            EventType.CIRCLE,
            groups,
            time=circle.time,
            beatmap=beatmap,
            pos=np.array(circle.position),
            last_pos=last_pos,
            new_combo=circle.new_combo,
            hitsound_ref_times=[circle.time],
            hitsounds=[circle.hitsound],
            additions=[circle.addition],
            # Taiko circles carry the scroll speed active at their time.
            scroll_speed=self.scroll_speed_at(circle.time, beatmap) if beatmap.mode == 1 else None,
        )

    def _parse_slider(self, slider: Slider, groups: list[Group], last_pos: npt.NDArray, beatmap: Beatmap) -> npt.NDArray:
        """Parse a slider hit object into head, anchor, end, and repeat events.

        Args:
            slider: Slider object.
            groups: List of groups to add to.
            last_pos: Last position of the hit objects.
            beatmap: Beatmap object.

        Returns:
            pos: Last position of the slider.
        """
        # Ignore sliders which are too big
        if len(slider.curve.points) >= 100:
            return last_pos

        last_pos = self._add_group(
            EventType.SLIDER_HEAD,
            groups,
            time=slider.time,
            beatmap=beatmap,
            pos=np.array(slider.position),
            last_pos=last_pos,
            new_combo=slider.new_combo,
            hitsound_ref_times=[slider.time],
            hitsounds=[slider.edge_sounds[0] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[0] if len(slider.edge_additions) > 0 else '0:0'],
            scroll_speed=self.scroll_speed_at(slider.time, beatmap) if self.add_sv else None,
        )

        # Duration of a single span (repeat) of the slider.
        duration: timedelta = (slider.end_time - slider.time) / slider.repeat
        control_point_count = len(slider.curve.points)

        def append_control_points(event_type: EventType, last_pos: npt.NDArray = last_pos) -> npt.NDArray:
            # Add every interior control point as an anchor of the given type.
            for i in range(1, control_point_count - 1):
                last_pos = add_anchor(event_type, i, last_pos)
            return last_pos

        def add_anchor(event_type: EventType, i: int, last_pos: npt.NDArray) -> npt.NDArray:
            # Slider version 1 spreads anchor times along the span; version 2
            # keeps them at the head time (they are timeless placeholders).
            return self._add_group(
                event_type,
                groups,
                time=slider.time + i / (control_point_count - 1) * duration if self.slider_version == 1 else slider.time,
                beatmap=beatmap,
                has_time=False,
                pos=np.array(slider.curve.points[i]),
                last_pos=last_pos,
            )

        if isinstance(slider.curve, Linear):
            last_pos = append_control_points(EventType.RED_ANCHOR, last_pos)
        elif isinstance(slider.curve, Catmull):
            last_pos = append_control_points(EventType.CATMULL_ANCHOR, last_pos)
        elif isinstance(slider.curve, Perfect):
            last_pos = append_control_points(EventType.PERFECT_ANCHOR, last_pos)
        elif isinstance(slider.curve, MultiBezier):
            for i in range(1, control_point_count - 1):
                if slider.curve.points[i] == slider.curve.points[i + 1]:
                    # Duplicated control point marks a red (sharp) anchor.
                    last_pos = add_anchor(EventType.RED_ANCHOR, i, last_pos)
                elif slider.curve.points[i] != slider.curve.points[i - 1]:
                    last_pos = add_anchor(EventType.BEZIER_ANCHOR, i, last_pos)

        if self.slider_version == 2:
            # Add last control point without time
            last_pos = self._add_group(
                EventType.LAST_ANCHOR,
                groups,
                time=slider.time,
                beatmap=beatmap,
                has_time=False,
                pos=np.array(slider.curve.points[-1]),
                last_pos=last_pos,
            )

        # Add body hitsounds and remaining edge hitsounds
        last_pos = self._add_group(
            EventType.SLIDER_END,
            groups,
            time=slider.time + duration,
            beatmap=beatmap,
            pos=np.array(slider.curve.points[-1]) if self.slider_version == 1 else None,
            last_pos=last_pos,
            hitsound_ref_times=[slider.time + timedelta(milliseconds=1)] + [slider.time + i * duration for i in range(1, slider.repeat)],
            hitsounds=[slider.hitsound] + [slider.edge_sounds[i] if len(slider.edge_sounds) > i else 0 for i in range(1, slider.repeat)],
            additions=[slider.addition] + [slider.edge_additions[i] if len(slider.edge_additions) > i else '0:0' for i in range(1, slider.repeat)],
        )

        return self._add_group(
            EventType.REPEAT_END,
            groups,
            time=slider.end_time,
            beatmap=beatmap,
            pos=np.array(slider.curve(1)),
            last_pos=last_pos,
            hitsound_ref_times=[slider.end_time],
            hitsounds=[slider.edge_sounds[-1] if len(slider.edge_sounds) > 0 else 0],
            additions=[slider.edge_additions[-1] if len(slider.edge_additions) > 0 else '0:0'],
        )

    def _parse_spinner(self, spinner: Spinner, groups: list[Group], beatmap: Beatmap) -> npt.NDArray:
        """Parse a spinner hit object.

        Args:
            spinner: Spinner object.
            groups: List of groups to add to.
            beatmap: Beatmap object.

        Returns:
            pos: Playfield centre, used as the position after a spinner.
        """
        self._add_group(
            EventType.SPINNER,
            groups,
            time=spinner.time,
            beatmap=beatmap,
        )
        self._add_group(
            EventType.SPINNER_END,
            groups,
            time=spinner.end_time,
            beatmap=beatmap,
            hitsound_ref_times=[spinner.end_time],
            hitsounds=[spinner.hitsound],
            additions=[spinner.addition],
        )
        return np.array((256, 192))

    def _parse_hold_note(self, hold_note: HoldNote, groups: list[Group], beatmap: Beatmap) -> npt.NDArray:
        """Parse a mania hold note hit object.

        Args:
            hold_note: Hold note object.
            groups: List of groups to add to.
            beatmap: Beatmap object.

        Returns:
            pos: Position of the hold note.
        """
        pos = np.array(hold_note.position)
        self._add_group(
            EventType.HOLD_NOTE,
            groups,
            time=hold_note.time,
            beatmap=beatmap,
            pos=pos,
            hitsound_ref_times=[hold_note.time],
            hitsounds=[hold_note.hitsound],
            additions=[hold_note.addition],
        )
        self._add_group(
            EventType.HOLD_NOTE_END,
            groups,
            time=hold_note.end_time,
            beatmap=beatmap,
            pos=pos,
        )
        return pos

    def _parse_drumroll(self, slider: Slider, groups: list[Group], beatmap: Beatmap):
        """Parse a taiko drumroll (slider) hit object.

        Args:
            slider: Slider object.
            groups: List of groups to add to.
            beatmap: Beatmap object.
        """
        self._add_group(
            EventType.DRUMROLL,
            groups,
            time=slider.time,
            beatmap=beatmap,
            hitsound_ref_times=[slider.time],
            hitsounds=[slider.hitsound],  # Edge hitsounds are not supported in drumrolls
            additions=[slider.addition],
            scroll_speed=self.scroll_speed_at(slider.time, beatmap),
        )
        self._add_group(
            EventType.DRUMROLL_END,
            groups,
            time=slider.end_time,
            beatmap=beatmap,
        )

    def _parse_denden(self, spinner: Spinner, groups: list[Group], beatmap: Beatmap):
        """Parse a taiko denden (spinner) hit object.

        Args:
            spinner: Spinner object.
            groups: List of groups to add to.
            beatmap: Beatmap object.
        """
        self._add_group(
            EventType.DENDEN,
            groups,
            time=spinner.time,
            beatmap=beatmap,
            hitsound_ref_times=[spinner.time],
            hitsounds=[spinner.hitsound],
            additions=[spinner.addition],
            scroll_speed=self.scroll_speed_at(spinner.time, beatmap),
        )
        self._add_group(
            EventType.DENDEN_END,
            groups,
            time=spinner.end_time,
            beatmap=beatmap,
        )
# Make this parser discoverable through AutoFeatureExtractor for CM3P checkpoints.
AutoFeatureExtractor.register(CM3PConfig, CM3PBeatmapParser)

# Public API of this module.
__all__ = ["CM3PBeatmapParser", "EventType", "Group", "load_beatmap", "get_song_length", "EVENT_TYPES_WITH_NEW_COMBO"]