Skip to content

Commit

Permalink
Merge pull request #179 from AIRobotZhang/patch-1
Browse files (browse the repository at this point in the history)
Update loader.py
  • Loading branch information
tnlin authored Jan 3, 2025
2 parents e0cc87a + 67e9698 commit 8dd65d5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions IOPO/Method-IOPO/src/llamafactory/data/loader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -140,14 +140,14 @@ def _get_merged_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
- stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ stage: Literal["pt", "sft", "rm", "ppo", "kto", "iopo"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
if dataset_names is None:
return None

datasets = []
for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
- if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
+ if ((stage == "rm" or stage == "iopo") and dataset_attr.ranking is False) or ((stage != "rm" and stage != "iopo") and dataset_attr.ranking is True):
raise ValueError("The dataset is not applicable in the current training stage.")

datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))
Expand Down Expand Up @@ -199,7 +199,7 @@ def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
- stage: Literal["pt", "sft", "rm", "ppo", "kto"],
+ stage: Literal["pt", "sft", "rm", "ppo", "kto", "iopo"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
Expand Down

0 comments on commit 8dd65d5

Please sign in to comment.