feat(connect): distinct + sort (#3677)
universalmind303 authored Jan 15, 2025
1 parent 5702720 commit 34d2036
Showing 3 changed files with 113 additions and 5 deletions.
10 changes: 9 additions & 1 deletion src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -1,4 +1,6 @@
use daft_core::count_mode::CountMode;
use daft_dsl::col;
use daft_schema::dtype::DataType;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;

@@ -97,7 +99,13 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl:

let [arg] = arguments;

let count = arg.count(CountMode::All);
let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) {
col("*")
} else {
arg
};

let count = arg.count(CountMode::All).cast(&DataType::Int64);

Ok(count)
}
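The change above rewrites Spark's count of a literal 1 into a count over col("*") and casts the result to Int64, matching Spark's LongType for counts. A minimal client-side sketch of the behavior this enables (the test name and data are illustrative; it assumes the spark_session fixture used by this repo's connect tests):

from pyspark.sql import functions as F


def test_count_literal_one(spark_session):
    # count(1) is rewritten server-side to a count over col("*"), cast to Int64.
    df = spark_session.createDataFrame([(1,), (2,), (1,)], ["id"])
    result = df.agg(F.count(F.lit(1)).alias("n")).collect()
    assert result[0]["n"] == 3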
95 changes: 91 additions & 4 deletions src/daft-connect/src/translation/logical_plan.rs
@@ -1,7 +1,7 @@
use std::sync::Arc;

use daft_core::prelude::Schema;
use daft_dsl::LiteralValue;
use daft_dsl::{col, LiteralValue};
use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder};
use daft_micropartition::{
partitioning::{
@@ -11,12 +11,19 @@ use daft_micropartition::{
MicroPartition,
};
use daft_table::Table;
use eyre::bail;
use eyre::{bail, Context};
use futures::TryStreamExt;
use spark_connect::{relation::RelType, Limit, Relation, ShowString};
use spark_connect::{
expression::{
sort_order::{NullOrdering, SortDirection},
SortOrder,
},
relation::RelType,
Deduplicate, Limit, Relation, ShowString, Sort,
};
use tracing::debug;

use crate::{not_yet_implemented, session::Session, Runner};
use crate::{not_yet_implemented, session::Session, util::FromOptionalField, Runner};

mod aggregate;
mod drop;
@@ -31,6 +38,8 @@ mod with_columns_renamed;

use pyo3::{intern, prelude::*};

use super::to_daft_expr;

#[derive(Clone)]
pub struct SparkAnalyzer<'a> {
pub session: &'a Session,
@@ -135,6 +144,8 @@ impl SparkAnalyzer<'_> {
};
self.show_string(plan_id, *ss).await
}
RelType::Deduplicate(rel) => self.deduplicate(*rel).await,
RelType::Sort(rel) => self.sort(*rel).await,
plan => not_yet_implemented!(r#"relation type: "{}""#, rel_name(&plan))?,
}
}
@@ -196,6 +207,82 @@ impl SparkAnalyzer<'_> {

self.create_in_memory_scan(plan_id as _, schema, vec![tbl])
}

async fn deduplicate(&self, deduplicate: Deduplicate) -> eyre::Result<LogicalPlanBuilder> {
let Deduplicate {
input,
column_names,
..
} = deduplicate;

if !column_names.is_empty() {
not_yet_implemented!("Deduplicate with column names")?;
}

let input = input.required("input")?;

let plan = Box::pin(self.to_logical_plan(*input)).await?;

plan.distinct().map_err(Into::into)
}

async fn sort(&self, sort: Sort) -> eyre::Result<LogicalPlanBuilder> {
let Sort {
input,
order,
is_global,
} = sort;

let input = input.required("input")?;

if is_global == Some(false) {
not_yet_implemented!("Non Global sort")?;
}

let plan = Box::pin(self.to_logical_plan(*input)).await?;
if order.is_empty() {
return plan
.sort(vec![col("*")], vec![false], vec![false])
.map_err(Into::into);
}
let mut sort_by = Vec::with_capacity(order.len());
let mut descending = Vec::with_capacity(order.len());
let mut nulls_first = Vec::with_capacity(order.len());

for SortOrder {
child,
direction,
null_ordering,
} in order
{
let expr = child.required("child")?;
let expr = to_daft_expr(&expr)?;

let sort_direction = SortDirection::try_from(direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;

let desc = match sort_direction {
SortDirection::Ascending => false,
SortDirection::Descending | SortDirection::Unspecified => true,
};

let null_ordering = NullOrdering::try_from(null_ordering)
.wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?;

let nf = match null_ordering {
NullOrdering::SortNullsUnspecified => desc,
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
};

sort_by.push(expr);
descending.push(desc);
nulls_first.push(nf);
}

plan.sort(sort_by, descending, nulls_first)
.map_err(Into::into)
}
}

fn rel_name(rel: &RelType) -> &str {
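Taken together, the new deduplicate and sort translations let a Spark Connect client's distinct() and sort()/orderBy() calls map onto Daft's distinct and sort logical plan nodes. A hedged usage sketch (the session fixture and data are illustrative; per the translation above, desc_nulls_last() should arrive as descending = true, nulls_first = false):

from pyspark.sql import functions as F


def distinct_then_sort(spark_session):
    df = spark_session.createDataFrame([(3, "c"), (1, "a"), (3, "c")], ["id", "name"])
    # distinct() arrives as a Deduplicate relation with no column_names -> plan.distinct()
    deduped = df.distinct()
    # sort() arrives as a Sort relation; desc_nulls_last() -> descending=True, nulls_first=False
    ordered = deduped.sort(F.col("id").desc_nulls_last())
    return ordered.collect()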
13 changes: 13 additions & 0 deletions tests/connect/test_distinct.py
@@ -0,0 +1,13 @@
from __future__ import annotations

from pyspark.sql import Row


def test_distinct(spark_session):
# Create simple DataFrame with single column
data = [(1,), (2,), (1,)]
df = spark_session.createDataFrame(data, ["id"]).distinct()

assert df.count() == 2, "DataFrame should have 2 rows"

assert df.sort().collect() == [Row(id=1), Row(id=2)], "DataFrame should contain expected values"
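A companion test for the sort path could look like the sketch below (hypothetical file tests/connect/test_sort.py, not part of this commit; it assumes the same spark_session fixture):

from __future__ import annotations

from pyspark.sql import Row


def test_sort_descending(spark_session):
    data = [(2,), (1,), (3,)]
    df = spark_session.createDataFrame(data, ["id"])

    # Descending sort exercises SortDirection::Descending with the default null ordering.
    result = df.sort(df.id.desc()).collect()

    assert result == [Row(id=3), Row(id=2), Row(id=1)], "rows should be sorted descending by id"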
