From 533120249e410c51431a3b8a10e71d9d580d59ad Mon Sep 17 00:00:00 2001 From: Oliver Borchert Date: Wed, 8 Nov 2023 10:31:46 +0100 Subject: [PATCH] Implement jameslamb's review comments --- tests/python_package_test/test_arrow.py | 28 +++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index 3e188f647138..40482a904a62 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -104,21 +104,26 @@ def test_dataset_construct_fuzzy( # -------------------------------------------- FIELDS ------------------------------------------- # -@pytest.mark.parametrize("field", ["label", "weight"]) -def test_dataset_construct_fields_fuzzy(field: str): +def test_dataset_construct_fields_fuzzy(): arrow_table = generate_random_arrow_table(3, 1000, 42) - arrow_array = generate_random_arrow_array(1000, 42) + arrow_labels = generate_random_arrow_array(1000, 42) + arrow_weights = generate_random_arrow_array(1000, 42) - arrow_dataset = lgb.Dataset(arrow_table, **{field: arrow_array}) + arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights) arrow_dataset.construct() - pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), **{field: arrow_array.to_numpy()}) + pandas_dataset = lgb.Dataset( + arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy() + ) pandas_dataset.construct() - np_assert_array_equal(arrow_dataset.get_field(field), pandas_dataset.get_field(field)) - np_assert_array_equal( - getattr(arrow_dataset, f"get_{field}")(), getattr(pandas_dataset, f"get_{field}")() - ) + # Check for equality + for field in ("label", "weight"): + np_assert_array_equal( + arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True + ) + np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True) + np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True) # -------------------------------------------- LABELS ------------------------------------------- # @@ -150,7 +155,7 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: dataset.construct() expected = np.array([0, 1, 0, 0, 1], dtype=np.float32) - np_assert_array_equal(expected, dataset.get_label()) + np_assert_array_equal(expected, dataset.get_label(), strict=True) # ------------------------------------------- WEIGHTS ------------------------------------------- # @@ -162,6 +167,7 @@ def test_dataset_construct_weights_none(): dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params()) dataset.construct() assert dataset.get_weight() is None + assert dataset.get_field("weight") is None @pytest.mark.parametrize( @@ -176,4 +182,4 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type dataset.construct() expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32) - np_assert_array_equal(expected, dataset.get_weight()) + np_assert_array_equal(expected, dataset.get_weight(), strict=True)