File size: 18,680 Bytes
12a8e0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
"""
BeatHeritage V1 Custom Postprocessor
Enhanced postprocessing for improved beatmap quality
"""

import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import logging

from osuT5.osuT5.inference.postprocessor import Postprocessor, BeatmapConfig

logger = logging.getLogger(__name__)


@dataclass
class BeatHeritageConfig(BeatmapConfig):
    """Enhanced config for BeatHeritage V1 postprocessing.

    Extends ``BeatmapConfig`` with the switches and tuning values read by
    ``BeatHeritagePostprocessor`` and its helper passes (``QualityController``,
    ``FlowOptimizer``, ``PatternEnhancer``). Every added field has a default.
    """
    # Quality control parameters
    # Minimum distance (playfield units) that QualityController.fix_spacing_issues
    # enforces between consecutive objects that both have a 'position'.
    min_distance_threshold: float = 20.0
    # Overlap fraction, relative to the fixed 60-unit span (2 * radius 30) used in
    # QualityController._objects_overlap, above which fix_overlaps pushes a later
    # object away from an earlier one.
    max_overlap_ratio: float = 0.15
    # Gates QualityController.fix_spacing_issues and fix_overlaps.
    enable_auto_correction: bool = True
    # Gates FlowOptimizer.optimize_flow (sharp-turn smoothing).
    enable_flow_optimization: bool = True
    
    # Pattern enhancement
    # Gates PatternEnhancer.enhance_patterns.
    enable_pattern_variety: bool = True
    # NOTE(review): not read by any code in this module; confirm whether another
    # module consumes it or whether it is dead configuration.
    pattern_complexity_target: float = 0.7
    
    # Difficulty scaling
    # Gates BeatHeritagePostprocessor._scale_difficulty.
    enable_difficulty_scaling: bool = True
    # NOTE(review): not read by any code in this module; confirm intended use.
    difficulty_variance_threshold: float = 0.3
    
    # Style preservation
    # Gates BeatHeritagePostprocessor._preserve_style.
    enable_style_preservation: bool = True
    # Blend weight for consecutive spacing in _preserve_style:
    # 0 keeps each current distance, 1 replaces it with the map-average distance.
    # Presumably meant to lie in [0, 1]; not validated here.
    style_consistency_weight: float = 0.8


class BeatHeritagePostprocessor(Postprocessor):
    """Enhanced postprocessor for BeatHeritage V1.

    Runs the base ``Postprocessor.postprocess`` step, then a chain of optional
    passes, each switched on by a flag of :class:`BeatHeritageConfig`:
    spacing/overlap correction, flow smoothing, pattern variation, difficulty
    scaling and style preservation.

    The passes operate on ``beatmap_data['hit_objects']``, a list of dicts that
    may carry a ``'position'`` ``(x, y)`` pair and/or a ``'distance'`` value.
    They mutate those dicts in place. Adjusted positions are clamped to
    ``0..512`` x ``0..384``.
    """

    def __init__(self, config: BeatHeritageConfig):
        super().__init__(config)
        self.config = config
        self.flow_optimizer = FlowOptimizer(config)
        self.pattern_enhancer = PatternEnhancer(config)
        self.quality_controller = QualityController(config)

    def postprocess(self, beatmap_data: Dict) -> Dict:
        """
        Enhanced postprocessing pipeline for BeatHeritage V1.

        Args:
            beatmap_data: Raw beatmap data from model.

        Returns:
            Processed beatmap data with enhancements (the passes mutate the
            hit-object dicts in place).
        """
        # NOTE(review): assumes the base Postprocessor exposes
        # postprocess(dict) -> dict; confirm against osuT5.
        beatmap_data = super().postprocess(beatmap_data)

        # Quality control runs first so later passes start from sane spacing.
        if self.config.enable_auto_correction:
            beatmap_data = self.quality_controller.fix_spacing_issues(beatmap_data)
            beatmap_data = self.quality_controller.fix_overlaps(beatmap_data)

        if self.config.enable_flow_optimization:
            beatmap_data = self.flow_optimizer.optimize_flow(beatmap_data)

        if self.config.enable_pattern_variety:
            beatmap_data = self.pattern_enhancer.enhance_patterns(beatmap_data)

        if self.config.enable_difficulty_scaling:
            beatmap_data = self._scale_difficulty(beatmap_data)

        if self.config.enable_style_preservation:
            beatmap_data = self._preserve_style(beatmap_data)

        return beatmap_data

    def _scale_difficulty(self, beatmap_data: Dict) -> Dict:
        """Scale per-object ``'distance'`` values toward the target difficulty.

        The target is ``config.difficulty``. When the config has no such
        attribute, or its value is ``None``, the data is returned unchanged.
        Only objects that carry a ``'distance'`` key are modified; positions
        are not touched.
        """
        # getattr: BeatHeritageConfig does not declare ``difficulty`` itself, so
        # skip scaling instead of raising AttributeError if the base lacks it.
        target_difficulty = getattr(self.config, 'difficulty', None)
        if target_difficulty is None:
            return beatmap_data

        current_difficulty = self._calculate_difficulty(beatmap_data)
        # Floor the divisor so an empty/degenerate map cannot divide by zero.
        scale_factor = target_difficulty / max(current_difficulty, 0.1)

        for obj in beatmap_data.get('hit_objects') or []:
            if 'distance' in obj:
                obj['distance'] *= scale_factor

        logger.info(
            "Scaled difficulty from %.2f to %.2f", current_difficulty, target_difficulty
        )
        return beatmap_data

    def _preserve_style(self, beatmap_data: Dict) -> Dict:
        """Pull consecutive spacing toward the map's average spacing.

        For each pair of consecutive objects that both have a position, the
        later object is moved along its current direction. Its new distance
        from the previous object is
        ``current * (1 - w) + average * w``, where ``w`` is
        ``config.style_consistency_weight``. Pairs are processed in order, so
        each adjustment sees the already-adjusted previous position.
        """
        objects = beatmap_data.get('hit_objects')
        if not objects:
            return beatmap_data

        style_features = self._extract_style_features(beatmap_data)
        expected_distance = style_features.get('avg_distance', 100)
        consistency_weight = self.config.style_consistency_weight

        # zip over the same dicts: mutating ``obj`` is visible as ``prev_obj``
        # on the next iteration, which preserves the sequential behaviour.
        for prev_obj, obj in zip(objects, objects[1:]):
            if 'position' not in obj or 'position' not in prev_obj:
                continue

            current_distance = self._calculate_distance(
                obj['position'], prev_obj['position']
            )
            adjusted_distance = (
                current_distance * (1 - consistency_weight)
                + expected_distance * consistency_weight
            )
            obj['position'] = self._adjust_position(
                prev_obj['position'], obj['position'], adjusted_distance
            )

        return beatmap_data

    def _calculate_difficulty(self, beatmap_data: Dict) -> float:
        """Return a rough difficulty heuristic clamped to [0, 10].

        This is a simplified estimate (object count x average spacing x BPM,
        each normalised), not an osu! star-rating computation. ``bpm``
        defaults to 180 when absent.
        """
        num_objects = len(beatmap_data.get('hit_objects', []))
        avg_spacing = self._calculate_avg_spacing(beatmap_data)
        bpm = beatmap_data.get('bpm', 180)

        difficulty = (num_objects / 100) * (avg_spacing / 50) * (bpm / 180)
        return min(max(difficulty, 0), 10)

    def _consecutive_distances(self, beatmap_data: Dict) -> List[float]:
        """Distances between consecutive objects that both have a position."""
        objects = beatmap_data.get('hit_objects') or []
        return [
            self._calculate_distance(prev['position'], cur['position'])
            for prev, cur in zip(objects, objects[1:])
            if 'position' in prev and 'position' in cur
        ]

    def _extract_style_features(self, beatmap_data: Dict) -> Dict:
        """Return ``avg_distance``/``distance_variance`` of consecutive spacing.

        Returns an empty dict when there is no pair of consecutive positioned
        objects. Pairs where either object lacks a position are skipped. The
        previous version substituted the playfield centre, which produced
        spurious distances and disagreed with ``_calculate_avg_spacing``.
        """
        features = {}
        distances = self._consecutive_distances(beatmap_data)
        if distances:
            features['avg_distance'] = np.mean(distances)
            features['distance_variance'] = np.var(distances)
        return features

    def _calculate_avg_spacing(self, beatmap_data: Dict) -> float:
        """Mean consecutive spacing, or 100 when no positioned pair exists."""
        distances = self._consecutive_distances(beatmap_data)
        return np.mean(distances) if distances else 100

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Euclidean distance between two ``(x, y)`` positions."""
        return np.hypot(pos1[0] - pos2[0], pos1[1] - pos2[1])

    def _adjust_position(self, from_pos: Tuple[float, float],
                         to_pos: Tuple[float, float],
                         target_distance: float) -> Tuple[float, float]:
        """Move ``to_pos`` along the ``from_pos -> to_pos`` ray to ``target_distance``.

        Coincident points (distance < 0.01) are returned unchanged because they
        define no direction. The result is clamped to the playfield, so the
        achieved distance can be shorter than requested near the edges.
        """
        current_distance = self._calculate_distance(from_pos, to_pos)
        if current_distance < 0.01:
            return to_pos

        scale = target_distance / current_distance
        dx = (to_pos[0] - from_pos[0]) * scale
        dy = (to_pos[1] - from_pos[1]) * scale

        new_x = max(0, min(512, from_pos[0] + dx))
        new_y = max(0, min(384, from_pos[1] + dy))
        return (new_x, new_y)


class FlowOptimizer:
    """Optimize flow patterns in beatmaps by softening sharp turns.

    For every object whose two predecessors also have positions, the heading
    of the incoming segment (i-2 -> i-1) is compared with the heading of the
    outgoing segment (i-1 -> i). If the signed turn exceeds
    ``SHARP_ANGLE_THRESHOLD`` degrees, the object is rotated around object i-1,
    keeping its distance, so that the turn becomes ``SMOOTHED_TURN`` degrees in
    the same rotational direction. Results are clamped to 0..512 x 0..384.
    """

    # Turns (in degrees) larger than this are treated as awkward.
    SHARP_ANGLE_THRESHOLD = 120.0
    # Turn (in degrees) that an over-sharp turn is reduced to.
    SMOOTHED_TURN = 90.0

    def __init__(self, config: BeatHeritageConfig):
        self.config = config

    def optimize_flow(self, beatmap_data: Dict) -> Dict:
        """Smooth sharp turns in ``beatmap_data['hit_objects']`` in place.

        Objects are processed in order, so each decision sees positions already
        adjusted for earlier objects. An object is left untouched when it, or
        either of its two predecessors, has no ``'position'``.

        Args:
            beatmap_data: Beatmap dict; returned unchanged if it has no
                ``'hit_objects'`` key.

        Returns:
            The same ``beatmap_data`` dict.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        objects = beatmap_data['hit_objects']
        for i in range(2, len(objects)):
            before, anchor, obj = objects[i - 2], objects[i - 1], objects[i]
            # A heading needs two real points. Substituting the playfield centre
            # for a missing position gave meaningless angles, and raised
            # KeyError below when the anchor itself had no position.
            if not ('position' in before and 'position' in anchor and 'position' in obj):
                continue

            anchor_pos = anchor['position']
            prev_angle = self._calculate_angle(before['position'], anchor_pos)
            current_angle = self._calculate_angle(anchor_pos, obj['position'])

            # arctan2 headings live in [-180, 180]; the raw difference must be
            # wrapped, otherwise 179 -> -179 (a 2-degree turn) reads as 358.
            turn = self._wrap_angle(current_angle - prev_angle)
            if abs(turn) > self.SHARP_ANGLE_THRESHOLD:
                smoothed_angle = prev_angle + np.sign(turn) * self.SMOOTHED_TURN
                distance = self._calculate_distance(anchor_pos, obj['position'])

                new_x = anchor_pos[0] + distance * np.cos(np.radians(smoothed_angle))
                new_y = anchor_pos[1] + distance * np.sin(np.radians(smoothed_angle))

                obj['position'] = (
                    max(0, min(512, new_x)),
                    max(0, min(384, new_y)),
                )

        return beatmap_data

    @staticmethod
    def _wrap_angle(angle: float) -> float:
        """Map an angle difference in degrees into the interval (-180, 180]."""
        return 180.0 - (180.0 - angle) % 360.0

    def _calculate_angle(self, pos1: Tuple[float, float],
                         pos2: Tuple[float, float]) -> float:
        """Heading of the segment ``pos1 -> pos2`` in degrees, in [-180, 180]."""
        return np.degrees(np.arctan2(pos2[1] - pos1[1], pos2[0] - pos1[0]))

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Euclidean distance between two ``(x, y)`` positions."""
        return np.sqrt((pos1[0] - pos2[0])**2 + (pos1[1] - pos2[1])**2)


class PatternEnhancer:
    """Enhance pattern variety by replacing back-to-back repeated shapes.

    Two consecutive windows of ``WINDOW_SIZE`` objects whose positions are
    close on average are treated as a repetition. The combined span is then
    re-laid out using a randomly chosen shape from a small built-in library,
    centred on the playfield centre.
    """

    # Number of objects per compared window.
    WINDOW_SIZE = 8
    # Point that library shapes are centred on, and the factor they are scaled by.
    PLAYFIELD_CENTER = (256, 192)
    PATTERN_SCALE = 2.0

    def __init__(self, config: BeatHeritageConfig):
        self.config = config
        self.pattern_library = self._load_pattern_library()

    def enhance_patterns(self, beatmap_data: Dict) -> Dict:
        """Replace detected repetitive sections with library shapes (in place).

        Returns ``beatmap_data`` unchanged if it has no ``'hit_objects'`` key.
        Uses NumPy's global RNG, so results vary between calls unless seeded.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        for section in self._detect_repetitive_patterns(beatmap_data):
            beatmap_data = self._vary_pattern(beatmap_data, section)

        return beatmap_data

    def _load_pattern_library(self) -> List[Dict]:
        """Return the built-in shapes as ``{'name', 'positions'}`` dicts.

        Coordinates are in a local 0..100 frame; ``_vary_pattern`` centres and
        scales them.
        """
        return [
            {'name': 'triangle', 'positions': [(0, 0), (100, 0), (50, 86.6)]},
            {'name': 'square', 'positions': [(0, 0), (100, 0), (100, 100), (0, 100)]},
            {'name': 'star', 'positions': [(50, 0), (61, 35), (97, 35), (68, 57), (79, 91), (50, 70), (21, 91), (32, 57), (3, 35), (39, 35)]},
            {'name': 'hexagon', 'positions': [(50, 0), (93, 25), (93, 75), (50, 100), (7, 75), (7, 25)]},
        ]

    def _detect_repetitive_patterns(self, beatmap_data: Dict) -> List[Tuple[int, int]]:
        """Find non-overlapping spans where two consecutive windows look alike.

        Returns:
            Half-open ``(start, end)`` index ranges of length ``2 * WINDOW_SIZE``,
            in ascending order and never overlapping.
        """
        objects = beatmap_data.get('hit_objects', [])
        window = self.WINDOW_SIZE
        span = 2 * window

        sections = []
        start = 0
        # ``<=``: the last valid start is len(objects) - span. The former
        # ``range(len - span)`` stopped one short, so e.g. a 16-object map was
        # never examined.
        while start + span <= len(objects):
            first = self._extract_pattern(objects[start:start + window])
            second = self._extract_pattern(objects[start + window:start + span])
            if self._patterns_similar(first, second):
                sections.append((start, start + span))
                # Jump past the match: overlapping sections would otherwise be
                # rewritten repeatedly, each overwriting the previous variation.
                start += span
            else:
                start += 1

        return sections

    def _extract_pattern(self, objects: List[Dict]) -> List[Tuple[float, float]]:
        """Positions of ``objects``, using the playfield centre when absent."""
        return [obj.get('position', (256, 192)) for obj in objects]

    def _patterns_similar(self, pattern1: List, pattern2: List, threshold: float = 0.8,
                          max_avg_distance: float = 50.0) -> bool:
        """Return True if two equal-length position lists are close on average.

        Args:
            pattern1, pattern2: Sequences of ``(x, y)`` positions.
            threshold: Kept for backward compatibility; not used.
            max_avg_distance: Mean point-to-point distance (playfield units)
                below which the patterns count as similar.

        Empty or differently sized patterns are never similar.
        """
        if len(pattern1) != len(pattern2) or not pattern1:
            return False

        distances = [
            np.hypot(pos1[0] - pos2[0], pos1[1] - pos2[1])
            for pos1, pos2 in zip(pattern1, pattern2)
        ]
        return bool(np.mean(distances) < max_avg_distance)

    def _vary_pattern(self, beatmap_data: Dict, section: Tuple[int, int]) -> Dict:
        """Re-lay out objects in ``section`` using a random library shape.

        The shape is centred on ``PLAYFIELD_CENTER`` (its centroid maps there)
        and scaled by ``PATTERN_SCALE``. Its points are cycled if the section
        is longer than the shape. Objects without a ``'position'`` are skipped.
        """
        if not self.pattern_library:
            return beatmap_data

        start, end = section
        objects = beatmap_data['hit_objects']

        # Pick by index; np.random.choice would first coerce the list of dicts
        # into an object array.
        pattern = self.pattern_library[np.random.randint(len(self.pattern_library))]
        pattern_positions = pattern['positions']
        count = len(pattern_positions)

        # Centre on the shape's centroid. Previously the shape's (0, 0) corner
        # sat at the playfield centre, so shapes extended to y = 392 and got
        # flattened by the 384 clamp.
        centroid_x = sum(p[0] for p in pattern_positions) / count
        centroid_y = sum(p[1] for p in pattern_positions) / count
        center_x, center_y = self.PLAYFIELD_CENTER
        scale = self.PATTERN_SCALE

        for i in range(start, min(end, len(objects))):
            if 'position' not in objects[i]:
                continue
            base_x, base_y = pattern_positions[(i - start) % count]
            new_x = center_x + (base_x - centroid_x) * scale
            new_y = center_y + (base_y - centroid_y) * scale
            objects[i]['position'] = (
                max(0, min(512, new_x)),
                max(0, min(384, new_y)),
            )

        return beatmap_data


class QualityController:
    """Repair spacing and overlap problems between nearby hit objects.

    Thresholds come from the config (``min_distance_threshold`` and
    ``max_overlap_ratio``). Every repositioned object is clamped to the
    0..512 x 0..384 playfield. Hit-object dicts are modified in place.
    """

    def __init__(self, config: BeatHeritageConfig):
        self.config = config

    def fix_spacing_issues(self, beatmap_data: Dict) -> Dict:
        """Push each object at least ``min_distance_threshold`` from its predecessor.

        Only consecutive pairs where both objects have a ``'position'`` are
        considered. A too-close object is placed exactly at the threshold
        distance, along its current direction from the predecessor (or to the
        right when the two coincide). Pairs are handled in order, so each
        check sees the already-moved predecessor.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        hit_objects = beatmap_data['hit_objects']
        floor_distance = self.config.min_distance_threshold

        for idx in range(1, len(hit_objects)):
            anchor, current = hit_objects[idx - 1], hit_objects[idx]
            if 'position' not in current or 'position' not in anchor:
                continue

            gap = self._calculate_distance(anchor['position'], current['position'])
            if gap < floor_distance:
                heading = self._get_direction(anchor['position'], current['position'])
                current['position'] = self._move_position(
                    anchor['position'], heading, floor_distance
                )

        return beatmap_data

    def fix_overlaps(self, beatmap_data: Dict) -> Dict:
        """Separate objects whose approximate circles overlap too much.

        Each object is compared with up to nine following objects (indices
        ``i+1 .. i+9``). When ``_objects_overlap`` reports overlap above
        ``max_overlap_ratio``, the later object is moved 60 units from the
        earlier one.
        """
        if 'hit_objects' not in beatmap_data:
            return beatmap_data

        hit_objects = beatmap_data['hit_objects']
        allowed_overlap = self.config.max_overlap_ratio
        total = len(hit_objects)
        lookahead = 10  # exclusive upper offset: compares offsets 1..9

        for first in range(total):
            for second in range(first + 1, min(first + lookahead, total)):
                earlier, later = hit_objects[first], hit_objects[second]
                if self._objects_overlap(earlier, later, allowed_overlap):
                    hit_objects[second] = self._adjust_for_overlap(earlier, later)

        return beatmap_data

    def _calculate_distance(self, pos1: Tuple[float, float],
                            pos2: Tuple[float, float]) -> float:
        """Straight-line distance between two ``(x, y)`` points."""
        delta_x = pos1[0] - pos2[0]
        delta_y = pos1[1] - pos2[1]
        return np.sqrt(delta_x**2 + delta_y**2)

    def _get_direction(self, from_pos: Tuple[float, float],
                       to_pos: Tuple[float, float]) -> Tuple[float, float]:
        """Unit vector from ``from_pos`` toward ``to_pos``; ``(1, 0)`` if they coincide."""
        delta_x = to_pos[0] - from_pos[0]
        delta_y = to_pos[1] - from_pos[1]
        magnitude = np.sqrt(delta_x**2 + delta_y**2)
        # Below 0.01 the direction is numerically meaningless: default to +x.
        return (1, 0) if magnitude < 0.01 else (delta_x / magnitude, delta_y / magnitude)

    @staticmethod
    def _clamp_to_playfield(x: float, y: float) -> Tuple[float, float]:
        """Clamp a point into 0..512 x 0..384."""
        return (max(0, min(512, x)), max(0, min(384, y)))

    def _move_position(self, from_pos: Tuple[float, float],
                       direction: Tuple[float, float],
                       distance: float) -> Tuple[float, float]:
        """Point ``distance`` along ``direction`` from ``from_pos``, clamped to the field."""
        return self._clamp_to_playfield(
            from_pos[0] + direction[0] * distance,
            from_pos[1] + direction[1] * distance,
        )

    def _objects_overlap(self, obj1: Dict, obj2: Dict, threshold: float) -> bool:
        """True when the overlap fraction of two radius-30 circles exceeds ``threshold``.

        Objects without a ``'position'`` never overlap. Sliders are treated as
        plain circles.
        """
        if 'position' not in obj1 or 'position' not in obj2:
            return False

        radius = 30  # approximate circle radius in playfield units
        diameter = 2 * radius
        gap = self._calculate_distance(obj1['position'], obj2['position'])
        overlap_fraction = max(0, diameter - gap) / diameter
        return overlap_fraction > threshold

    def _adjust_for_overlap(self, obj1: Dict, obj2: Dict) -> Dict:
        """Move ``obj2`` to 60 units from ``obj1`` along their current direction.

        Mutates and returns ``obj2``; returns it untouched when either object
        lacks a position.
        """
        if 'position' not in obj1 or 'position' not in obj2:
            return obj2

        safe_gap = 60  # two radii: circles just touch
        heading = self._get_direction(obj1['position'], obj2['position'])
        obj2['position'] = self._move_position(obj1['position'], heading, safe_gap)
        return obj2


# Public API for ``from ... import *``: the postprocessor and its config.
# The helper passes (FlowOptimizer, PatternEnhancer, QualityController) are
# deliberately left out of this list but can still be imported by name.
__all__ = ['BeatHeritagePostprocessor', 'BeatHeritageConfig']