Fix typo
Browse files
custom_generate/generate.py
CHANGED
|
@@ -243,7 +243,7 @@ def _group_beam_search(
|
|
| 243 |
batch_beam_size, cur_len = input_ids.shape
|
| 244 |
# Does not exist anymore in recent versions!
|
| 245 |
if hasattr(model, "_get_initial_cache_position"):
|
| 246 |
-
model_kwargs = model._get_initial_cache_position(
|
| 247 |
|
| 248 |
if return_dict_in_generate and output_scores:
|
| 249 |
beam_indices = [
|
|
|
|
| 243 |
batch_beam_size, cur_len = input_ids.shape
|
| 244 |
# Does not exist anymore in recent versions!
|
| 245 |
if hasattr(model, "_get_initial_cache_position"):
|
| 246 |
+
model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
| 247 |
|
| 248 |
if return_dict_in_generate and output_scores:
|
| 249 |
beam_indices = [
|