feat(sql): Adds url_download and url_upload to daft-sql (#3690)
RCHowell authored Jan 16, 2025
1 parent 34d2036 commit c650794
Showing 15 changed files with 382 additions and 89 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions daft/expressions/expressions.py
@@ -1431,6 +1431,7 @@ def upload(
multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime()
# If the user specifies a single location via a string, we should upload to a single folder. Otherwise,
# if the user gave an expression, we assume that each row has a specific url to upload to.
# Consider moving the check for is_single_folder to a lower IR.
is_single_folder = isinstance(location, str)
io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config)
return Expression._from_pyexpr(
2 changes: 0 additions & 2 deletions src/daft-functions/Cargo.toml
@@ -3,7 +3,6 @@ arrow2 = {workspace = true}
base64 = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-io-config = {path = "../common/io-config", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
@@ -25,7 +24,6 @@ snafu.workspace = true
[features]
python = [
"common-error/python",
"common-io-config/python",
"daft-core/python",
"daft-dsl/python",
"daft-image/python",
18 changes: 8 additions & 10 deletions src/daft-functions/src/python/uri.rs
@@ -2,6 +2,8 @@ use daft_dsl::python::PyExpr;
use daft_io::python::IOConfig;
use pyo3::{exceptions::PyValueError, pyfunction, PyResult};

use crate::uri::{self, download::UrlDownloadArgs, upload::UrlUploadArgs};

#[pyfunction]
pub fn url_download(
expr: PyExpr,
@@ -15,15 +17,13 @@ pub fn url_download(
"max_connections must be positive and non_zero: {max_connections}"
)));
}

Ok(crate::uri::download(
expr.into(),
let args = UrlDownloadArgs::new(
max_connections as usize,
raise_error_on_failure,
multi_thread,
Some(config.config),
)
.into())
);
Ok(uri::download(expr.into(), Some(args)).into())
}

#[pyfunction(signature = (
@@ -49,14 +49,12 @@ pub fn url_upload(
"max_connections must be positive and non_zero: {max_connections}"
)));
}
Ok(crate::uri::upload(
expr.into(),
folder_location.into(),
let args = UrlUploadArgs::new(
max_connections as usize,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config.map(|io_config| io_config.config),
)
.into())
);
Ok(uri::upload(expr.into(), folder_location.into(), Some(args)).into())
}
52 changes: 44 additions & 8 deletions src/daft-functions/src/uri/download.rs
@@ -11,16 +11,52 @@ use snafu::prelude::*;

use crate::InvalidArgumentSnafu;

/// Container for the keyword arguments of `url_download`
/// ex:
/// ```text
/// url_download(input)
/// url_download(input, max_connections=32)
/// url_download(input, on_error='raise')
/// url_download(input, on_error='null')
/// url_download(input, max_connections=32, on_error='raise')
/// ```
#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct DownloadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) config: Arc<IOConfig>,
pub struct UrlDownloadArgs {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub io_config: Arc<IOConfig>,
}

impl UrlDownloadArgs {
pub fn new(
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
io_config: Option<IOConfig>,
) -> Self {
Self {
max_connections,
raise_error_on_failure,
multi_thread,
io_config: io_config.unwrap_or_default().into(),
}
}
}

impl Default for UrlDownloadArgs {
fn default() -> Self {
Self {
max_connections: 32,
raise_error_on_failure: true,
multi_thread: true,
io_config: IOConfig::default().into(),
}
}
}

#[typetag::serde]
impl ScalarUDF for DownloadFunction {
impl ScalarUDF for UrlDownloadArgs {
fn as_any(&self) -> &dyn std::any::Any {
self
}
@@ -34,7 +70,7 @@ impl ScalarUDF for DownloadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
config,
io_config,
} = self;

match inputs {
@@ -47,7 +83,7 @@ impl ScalarUDF for DownloadFunction {
*max_connections,
*raise_error_on_failure,
*multi_thread,
config.clone(),
io_config.clone(),
Some(io_stats),
)?;
Ok(result.into_series())
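
The new `UrlDownloadArgs` struct bundles what used to be four separate parameters. A minimal sketch of how it might be constructed, assuming `daft-functions` is available as a workspace dependency (the argument values other than the defaults are illustrative):

```rust
use daft_functions::uri::download::UrlDownloadArgs;

fn main() {
    // Explicit construction; None for io_config falls back to IOConfig::default().
    let args = UrlDownloadArgs::new(8, true, true, None);
    assert_eq!(args.max_connections, 8);

    // The Default impl added in this diff fixes the fallback values:
    // max_connections = 32, raise_error_on_failure = true, multi_thread = true.
    let defaults = UrlDownloadArgs::default();
    assert_eq!(defaults.max_connections, 32);
    assert!(defaults.raise_error_on_failure && defaults.multi_thread);
}
```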
52 changes: 10 additions & 42 deletions src/daft-functions/src/uri/mod.rs
@@ -1,50 +1,18 @@
mod download;
mod upload;
pub mod download;
pub mod upload;

use common_io_config::IOConfig;
use daft_dsl::{functions::ScalarFunction, ExprRef};
use download::DownloadFunction;
use upload::UploadFunction;
use download::UrlDownloadArgs;
use upload::UrlUploadArgs;

/// Creates a `url_download` ExprRef from the positional and optional named arguments.
#[must_use]
pub fn download(
input: ExprRef,
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
config: Option<IOConfig>,
) -> ExprRef {
ScalarFunction::new(
DownloadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
config: config.unwrap_or_default().into(),
},
vec![input],
)
.into()
pub fn download(input: ExprRef, args: Option<UrlDownloadArgs>) -> ExprRef {
ScalarFunction::new(args.unwrap_or_default(), vec![input]).into()
}

/// Creates a `url_upload` ExprRef from the positional and optional named arguments.
#[must_use]
pub fn upload(
input: ExprRef,
location: ExprRef,
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
is_single_folder: bool,
config: Option<IOConfig>,
) -> ExprRef {
ScalarFunction::new(
UploadFunction {
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
config: config.unwrap_or_default().into(),
},
vec![input, location],
)
.into()
pub fn upload(input: ExprRef, location: ExprRef, args: Option<UrlUploadArgs>) -> ExprRef {
ScalarFunction::new(args.unwrap_or_default(), vec![input, location]).into()
}
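
With the args structs in place, the free functions reduce to a single optional-argument form: passing `None` picks up the defaults via `unwrap_or_default()`. A hedged usage sketch, mirroring the call shape used in the filter-pushdown test later in this diff (the column name `"a"` is illustrative):

```rust
use daft_dsl::{col, ExprRef};
use daft_functions::uri::{self, download::UrlDownloadArgs};

fn main() {
    // Passing None is equivalent to Some(UrlDownloadArgs::default()):
    // 32 connections, raise on failure, multi-threaded runtime, default IOConfig.
    let _defaults: ExprRef = uri::download(col("a"), None);

    // Explicit arguments, matching the call in the filter-pushdown test below.
    let _explicit: ExprRef = uri::download(
        col("a"),
        Some(UrlDownloadArgs::new(1, true, true, None)),
    );
}
```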
48 changes: 39 additions & 9 deletions src/daft-functions/src/uri/upload.rs
@@ -9,16 +9,46 @@ use futures::{StreamExt, TryStreamExt};
use serde::Serialize;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct UploadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) is_single_folder: bool,
pub(super) config: Arc<IOConfig>,
pub struct UrlUploadArgs {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub is_single_folder: bool,
pub io_config: Arc<IOConfig>,
}

impl UrlUploadArgs {
pub fn new(
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
is_single_folder: bool,
io_config: Option<IOConfig>,
) -> Self {
Self {
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config: io_config.unwrap_or_default().into(),
}
}
}

impl Default for UrlUploadArgs {
fn default() -> Self {
Self {
max_connections: 32,
raise_error_on_failure: true,
multi_thread: true,
is_single_folder: false,
io_config: IOConfig::default().into(),
}
}
}

#[typetag::serde]
impl ScalarUDF for UploadFunction {
impl ScalarUDF for UrlUploadArgs {
fn as_any(&self) -> &dyn std::any::Any {
self
}
@@ -29,11 +59,11 @@ impl ScalarUDF for UploadFunction {

fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
let Self {
config,
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
io_config,
} = self;

match inputs {
@@ -44,7 +74,7 @@
*raise_error_on_failure,
*multi_thread,
*is_single_folder,
config.clone(),
io_config.clone(),
None,
),
_ => Err(DaftError::ValueError(format!(
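
`UrlUploadArgs` follows the same pattern, with the extra `is_single_folder` flag that mirrors the check added in `expressions.py`: a literal folder location means every row uploads into one folder, while an expression means each row carries its own destination URL. A hedged sketch (column names and the `s3://` location are placeholders):

```rust
use daft_dsl::{col, lit};
use daft_functions::uri::{self, upload::UrlUploadArgs};

fn main() {
    // Upload every value in column "data" into a single folder.
    let args = UrlUploadArgs::new(
        32,    // max_connections
        true,  // raise_error_on_failure
        true,  // multi_thread
        true,  // is_single_folder: one folder for all rows
        None,  // io_config: use IOConfig::default()
    );
    let _single = uri::upload(col("data"), lit("s3://bucket/folder"), Some(args));

    // Passing None uses UrlUploadArgs::default(), where is_single_folder = false
    // (each row in the location column is its own destination URL).
    let _per_row = uri::upload(col("data"), col("dest_urls"), None);
}
```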
@@ -356,6 +356,7 @@ mod tests {
use common_scan_info::Pushdowns;
use daft_core::prelude::*;
use daft_dsl::{col, lit};
use daft_functions::uri::download::UrlDownloadArgs;
use rstest::rstest;

use crate::{
@@ -435,7 +436,10 @@
/// Tests that we can't pushdown a filter into a ScanOperator if it has an udf-ish expression.
#[test]
fn filter_with_udf_not_pushed_down_into_scan() -> DaftResult<()> {
let pred = daft_functions::uri::download(col("a"), 1, true, true, None);
let pred = daft_functions::uri::download(
col("a"),
Some(UrlDownloadArgs::new(1, true, true, None)),
);
let plan = dummy_scan_node(dummy_scan_operator(vec![
Field::new("a", DataType::Int64),
Field::new("b", DataType::Utf8),
31 changes: 30 additions & 1 deletion src/daft-sql/src/functions.rs
@@ -14,7 +14,7 @@ use crate::{
coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat,
SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric,
SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs,
SQLModuleTemporal, SQLModuleUtf8,
SQLModuleTemporal, SQLModuleUri, SQLModuleUtf8,
},
planner::SQLPlanner,
unsupported_sql_err,
Expand All @@ -36,6 +36,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
functions.register::<SQLModuleSketch>();
functions.register::<SQLModuleStructs>();
functions.register::<SQLModuleTemporal>();
functions.register::<SQLModuleUri>();
functions.register::<SQLModuleUtf8>();
functions.register::<SQLModuleConfig>();
functions.add_fn("coalesce", SQLCoalesce {});
@@ -375,3 +376,31 @@ impl<'a> SQLPlanner<'a> {
}
}
}

/// A namespace for function argument parsing helpers.
pub(crate) mod args {
use common_io_config::IOConfig;

use super::SQLFunctionArguments;
use crate::{error::PlannerError, modules::config::expr_to_iocfg, unsupported_sql_err};

/// Parses on_error => Literal['raise', 'null'] = 'raise' or err.
pub(crate) fn parse_on_error(args: &SQLFunctionArguments) -> Result<bool, PlannerError> {
match args.try_get_named::<String>("on_error")?.as_deref() {
None => Ok(true),
Some("raise") => Ok(true),
Some("null") => Ok(false),
Some(other) => {
unsupported_sql_err!("Expected on_error to be 'raise' or 'null', found '{other}'")
}
}
}

/// Parses io_config which is used in several SQL functions.
pub(crate) fn parse_io_config(args: &SQLFunctionArguments) -> Result<IOConfig, PlannerError> {
args.get_named("io_config")
.map(expr_to_iocfg)
.transpose()
.map(|op| op.unwrap_or_default())
}
}
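
The `on_error` mapping used by `parse_on_error` is small: absent or `'raise'` means raise on failure (`true`), `'null'` means return null instead (`false`), and anything else is rejected. A standalone sketch of just that mapping, independent of `SQLFunctionArguments` (the real helper takes the parsed SQL arguments, which this snippet does not model):

```rust
/// Mirror of the on_error mapping from functions::args::parse_on_error.
fn on_error_to_raise(on_error: Option<&str>) -> Result<bool, String> {
    match on_error {
        None | Some("raise") => Ok(true),
        Some("null") => Ok(false),
        Some(other) => Err(format!(
            "Expected on_error to be 'raise' or 'null', found '{other}'"
        )),
    }
}

fn main() {
    assert_eq!(on_error_to_raise(None), Ok(true));
    assert_eq!(on_error_to_raise(Some("null")), Ok(false));
    assert!(on_error_to_raise(Some("skip")).is_err());
}
```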
17 changes: 2 additions & 15 deletions src/daft-sql/src/modules/image/decode.rs
@@ -4,7 +4,7 @@ use sqlparser::ast::FunctionArg;

use crate::{
error::{PlannerError, SQLPlannerResult},
functions::{SQLFunction, SQLFunctionArguments},
functions::{self, SQLFunction, SQLFunctionArguments},
unsupported_sql_err,
};

@@ -21,20 +21,7 @@ impl TryFrom<SQLFunctionArguments> for ImageDecode {
_ => unsupported_sql_err!("Expected mode to be a string"),
})
.transpose()?;

let raise_on_error = args
.get_named("on_error")
.map(|arg| match arg.as_ref() {
Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() {
"raise" => Ok(true),
"null" => Ok(false),
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
},
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
})
.transpose()?
.unwrap_or(true);

let raise_on_error = functions::args::parse_on_error(&args)?;
Ok(Self {
mode,
raise_on_error,