Skip to content

Commit

Permalink
Implement jameslamb's review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Nov 8, 2023
1 parent 678ae7d commit 5331202
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,26 @@ def test_dataset_construct_fuzzy(
# -------------------------------------------- FIELDS ------------------------------------------- #


-@pytest.mark.parametrize("field", ["label", "weight"])
-def test_dataset_construct_fields_fuzzy(field: str):
+def test_dataset_construct_fields_fuzzy():
     arrow_table = generate_random_arrow_table(3, 1000, 42)
-    arrow_array = generate_random_arrow_array(1000, 42)
+    arrow_labels = generate_random_arrow_array(1000, 42)
+    arrow_weights = generate_random_arrow_array(1000, 42)

-    arrow_dataset = lgb.Dataset(arrow_table, **{field: arrow_array})
+    arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
     arrow_dataset.construct()

-    pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), **{field: arrow_array.to_numpy()})
+    pandas_dataset = lgb.Dataset(
+        arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
+    )
     pandas_dataset.construct()

-    np_assert_array_equal(arrow_dataset.get_field(field), pandas_dataset.get_field(field))
-    np_assert_array_equal(
-        getattr(arrow_dataset, f"get_{field}")(), getattr(pandas_dataset, f"get_{field}")()
-    )
+    # Check for equality
+    for field in ("label", "weight"):
+        np_assert_array_equal(
+            arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
+        )
+    np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True)
+    np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True)


# -------------------------------------------- LABELS ------------------------------------------- #
Expand Down Expand Up @@ -150,7 +155,7 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type:
     dataset.construct()

     expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
-    np_assert_array_equal(expected, dataset.get_label())
+    np_assert_array_equal(expected, dataset.get_label(), strict=True)


# ------------------------------------------- WEIGHTS ------------------------------------------- #
Expand All @@ -162,6 +167,7 @@ def test_dataset_construct_weights_none():
     dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params())
     dataset.construct()
     assert dataset.get_weight() is None
+    assert dataset.get_field("weight") is None


@pytest.mark.parametrize(
Expand All @@ -176,4 +182,4 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type
     dataset.construct()

     expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
-    np_assert_array_equal(expected, dataset.get_weight())
+    np_assert_array_equal(expected, dataset.get_weight(), strict=True)

0 comments on commit 5331202

Please sign in to comment.