Skip to content

Commit b5a9735

Browse files
GH963 Restrict callable to hashableT for read_csv (#1234)
* GH963 Restrict callable to hahableT for read_csv * GH963 Add test
1 parent f8a329d commit b5a9735

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

‎pandas-stubs/io/parsers/readers.pyi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ from pandas._typing import (
2727
DtypeArg,
2828
DtypeBackend,
2929
FilePath,
30+
HashableT,
3031
ListLikeHashable,
3132
ReadCsvBuffer,
3233
StorageOptions,
@@ -44,7 +45,7 @@ def read_csv(
4445
header: int | Sequence[int] | Literal["infer"] | None = ...,
4546
names: ListLikeHashable | None = ...,
4647
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
47-
usecols: UsecolsArgType = ...,
48+
usecols: UsecolsArgType[HashableT] = ...,
4849
dtype: DtypeArg | defaultdict | None = ...,
4950
engine: CSVEngine | None = ...,
5051
converters: (
@@ -108,7 +109,7 @@ def read_csv(
108109
header: int | Sequence[int] | Literal["infer"] | None = ...,
109110
names: ListLikeHashable | None = ...,
110111
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
111-
usecols: UsecolsArgType = ...,
112+
usecols: UsecolsArgType[HashableT] = ...,
112113
dtype: DtypeArg | defaultdict | None = ...,
113114
engine: CSVEngine | None = ...,
114115
converters: (
@@ -172,7 +173,7 @@ def read_csv(
172173
header: int | Sequence[int] | Literal["infer"] | None = ...,
173174
names: ListLikeHashable | None = ...,
174175
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
175-
usecols: UsecolsArgType = ...,
176+
usecols: UsecolsArgType[HashableT] = ...,
176177
dtype: DtypeArg | defaultdict | None = ...,
177178
engine: CSVEngine | None = ...,
178179
converters: (
@@ -236,7 +237,7 @@ def read_table(
236237
header: int | Sequence[int] | Literal["infer"] | None = ...,
237238
names: ListLikeHashable | None = ...,
238239
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
239-
usecols: UsecolsArgType = ...,
240+
usecols: UsecolsArgType[HashableT] = ...,
240241
dtype: DtypeArg | defaultdict | None = ...,
241242
engine: CSVEngine | None = ...,
242243
converters: (
@@ -300,7 +301,7 @@ def read_table(
300301
header: int | Sequence[int] | Literal["infer"] | None = ...,
301302
names: ListLikeHashable | None = ...,
302303
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
303-
usecols: UsecolsArgType = ...,
304+
usecols: UsecolsArgType[HashableT] = ...,
304305
dtype: DtypeArg | defaultdict | None = ...,
305306
engine: CSVEngine | None = ...,
306307
converters: (
@@ -364,7 +365,7 @@ def read_table(
364365
header: int | Sequence[int] | Literal["infer"] | None = ...,
365366
names: ListLikeHashable | None = ...,
366367
index_col: int | str | Sequence[str | int] | Literal[False] | None = ...,
367-
usecols: UsecolsArgType = ...,
368+
usecols: UsecolsArgType[HashableT] = ...,
368369
dtype: DtypeArg | defaultdict | None = ...,
369370
engine: CSVEngine | None = ...,
370371
converters: (

‎tests/test_io.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,11 @@ def test_read_csv():
575575
DataFrame,
576576
)
577577

578+
def cols(x: str) -> bool:
579+
return x in ["a", "b"]
580+
581+
pd.read_csv(path, usecols=cols)
582+
578583

579584
def test_read_csv_iterator():
580585
with ensure_clean() as path:
@@ -727,6 +732,11 @@ def test_types_read_csv() -> None:
727732
pd.read_csv(path, names="abcd") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
728733
pd.read_csv(path, usecols="abcd") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
729734

735+
def cols2(x: set[float]) -> bool:
736+
return sum(x) < 1.0
737+
738+
pd.read_csv("file.csv", usecols=cols2) # type: ignore[type-var] # pyright: ignore[reportArgumentType]
739+
730740
tfr1 = pd.read_csv(path, nrows=2, iterator=True, chunksize=3)
731741
check(assert_type(tfr1, TextFileReader), TextFileReader)
732742
tfr1.close()

0 commit comments

Comments
 (0)