Skip to content

Commit 9806c17

Browse files
authored
feat: add redshift and rds data api query params (#3111)
* feat: add redshift data api params
* mypy
* fix expected value
* add params to rds data api
1 parent bcb1041 commit 9806c17

File tree

3 files changed

+92
-6
lines changed

3 files changed

+92
-6
lines changed

‎awswrangler/data_api/rds.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ def connect(
256256
return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs)
257257

258258

259-
def read_sql_query(sql: str, con: RdsDataApi, database: str | None = None) -> pd.DataFrame:
259+
def read_sql_query(
260+
sql: str, con: RdsDataApi, database: str | None = None, parameters: list[dict[str, Any]] | None = None
261+
) -> pd.DataFrame:
260262
"""Run an SQL query on an RdsDataApi connection and return the result as a DataFrame.
261263
262264
Parameters
@@ -267,12 +269,31 @@ def read_sql_query(sql: str, con: RdsDataApi, database: str | None = None) -> pd
267269
A RdsDataApi connection instance
268270
database
269271
Database to run query on - defaults to the database specified by `con`.
272+
parameters
273+
A list of named parameters e.g. [{"name": "col", "value": {"stringValue": "val1"}}].
270274
271275
Returns
272276
-------
273277
A Pandas DataFrame containing the query results.
278+
279+
Examples
280+
--------
281+
>>> import awswrangler as wr
282+
>>> df = wr.data_api.rds.read_sql_query(
283+
>>> sql="SELECT * FROM public.my_table",
284+
>>> con=con,
285+
>>> )
286+
287+
>>> import awswrangler as wr
288+
>>> df = wr.data_api.rds.read_sql_query(
289+
>>> sql="SELECT * FROM public.my_table WHERE col = :col1",
290+
>>> con=con,
291+
>>> parameters=[
292+
>>> {"name": "col1", "value": {"stringValue": "val1"}}
293+
>>> ],
294+
>>> )
274295
"""
275-
return con.execute(sql, database=database)
296+
return con.execute(sql, database=database, parameters=parameters)
276297

277298

278299
def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None:

‎awswrangler/data_api/redshift.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,6 @@ def _execute_statement(
112112
) -> str:
113113
if transaction_id:
114114
raise exceptions.InvalidArgument("`transaction_id` not supported for Redshift Data API")
115-
if parameters:
116-
raise exceptions.InvalidArgument("`parameters` not supported for Redshift Data API")
117115

118116
self._validate_redshift_target()
119117
self._validate_auth_method()
@@ -130,6 +128,8 @@ def _execute_statement(
130128
args["ClusterIdentifier"] = self.cluster_id
131129
if self.workgroup_name:
132130
args["WorkgroupName"] = self.workgroup_name
131+
if parameters:
132+
args["Parameters"] = parameters # type: ignore[assignment]
133133

134134
_logger.debug("Executing %s", sql)
135135
response = self.client.execute_statement(
@@ -285,7 +285,12 @@ def connect(
285285
)
286286

287287

288-
def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None) -> pd.DataFrame:
288+
def read_sql_query(
289+
sql: str,
290+
con: RedshiftDataApi,
291+
database: str | None = None,
292+
parameters: list[dict[str, Any]] | None = None,
293+
) -> pd.DataFrame:
289294
"""Run an SQL query on a RedshiftDataApi connection and return the result as a DataFrame.
290295
291296
Parameters
@@ -296,9 +301,29 @@ def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None)
296301
A RedshiftDataApi connection instance
297302
database
298303
Database to run query on - defaults to the database specified by `con`.
304+
parameters
305+
A list of named parameters e.g. [{"name": "id", "value": "42"}].
299306
300307
Returns
301308
-------
302309
A Pandas DataFrame containing the query results.
310+
311+
Examples
312+
--------
313+
>>> import awswrangler as wr
314+
>>> df = wr.data_api.redshift.read_sql_query(
315+
>>> sql="SELECT * FROM public.my_table",
316+
>>> con=con,
317+
>>> )
318+
319+
>>> import awswrangler as wr
320+
>>> df = wr.data_api.redshift.read_sql_query(
321+
>>> sql="SELECT * FROM public.my_table WHERE id >= :id",
322+
>>> con=con,
323+
>>> parameters=[
324+
>>> {"name": "id", "value": "42"},
325+
>>> ],
326+
>>> )
327+
303328
"""
304-
return con.execute(sql, database=database)
329+
return con.execute(sql, database=database, parameters=parameters)

‎tests/unit/test_data_api.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,24 @@ def test_data_api_redshift_basic_select(redshift_connector: "RedshiftDataApi", r
124124
assert_pandas_equals(dataframe, expected_dataframe)
125125

126126

127+
def test_data_api_redshift_parameters(redshift_connector: "RedshiftDataApi", redshift_table: str) -> None:
128+
wr.data_api.redshift.read_sql_query(
129+
f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector
130+
)
131+
wr.data_api.redshift.read_sql_query(
132+
f"INSERT INTO public.{redshift_table} VALUES (41, 'test1'), (42, 'test2')", con=redshift_connector
133+
)
134+
expected_dataframe = pd.DataFrame([[42, "test2"]], columns=["id", "name"])
135+
136+
dataframe = wr.data_api.redshift.read_sql_query(
137+
f"SELECT * FROM public.{redshift_table} WHERE id >= :id",
138+
con=redshift_connector,
139+
parameters=[{"name": "id", "value": "42"}],
140+
)
141+
142+
assert_pandas_equals(dataframe, expected_dataframe)
143+
144+
127145
def test_data_api_redshift_empty_results_select(redshift_connector: "RedshiftDataApi", redshift_table: str) -> None:
128146
wr.data_api.redshift.read_sql_query(
129147
f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector
@@ -301,3 +319,25 @@ def test_data_api_postgresql(postgresql_serverless_connector: "RdsDataApi", post
301319
)
302320
expected_dataframe = pd.DataFrame([["test"]], columns=["name"])
303321
assert_pandas_equals(out_frame, expected_dataframe)
322+
323+
324+
def test_data_api_mysql_parameters(
325+
mysql_serverless_connector: "RdsDataApi",
326+
mysql_serverless_table: str,
327+
) -> None:
328+
database = "test"
329+
df = pd.DataFrame([[42, "test"]], columns=["id", "name"])
330+
331+
wr.data_api.rds.to_sql(
332+
df=df,
333+
con=mysql_serverless_connector,
334+
table=mysql_serverless_table,
335+
database=database,
336+
)
337+
338+
out_df = wr.data_api.rds.read_sql_query(
339+
f"SELECT * FROM {database}.{mysql_serverless_table} WHERE name = :name",
340+
con=mysql_serverless_connector,
341+
parameters=[{"name": "name", "value": {"stringValue": "test"}}],
342+
)
343+
assert_pandas_equals(out_df, df)

0 commit comments

Comments
 (0)