Skip to content

Commit a6faaf3

Browse files
GH1182 Add numpy integer, floating, complex to Scalar (#1192)
1 parent 351810b commit a6faaf3

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

‎pandas-stubs/_typing.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,9 @@ IndexIterScalar: TypeAlias = (
518518
| Timestamp
519519
| Timedelta
520520
)
521-
Scalar: TypeAlias = IndexIterScalar | complex
521+
Scalar: TypeAlias = (
522+
IndexIterScalar | complex | np.integer | np.floating | np.complexfloating
523+
)
522524
ScalarT = TypeVar("ScalarT", bound=Scalar)
523525
# Refine the definitions below in 3.9 to use the specialized type.
524526
np_ndarray_int64: TypeAlias = npt.NDArray[np.int64]

‎tests/test_frame.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,15 @@ def test_types_to_numpy() -> None:
15311531
# na_value param was added in 1.1.0 https://pandas.pydata.org/docs/whatsnew/v1.1.0.html
15321532
check(assert_type(df.to_numpy(na_value=0), np.ndarray), np.ndarray)
15331533

1534+
df = pd.DataFrame(data={"col1": [1, 1, 2]}, dtype=np.complex128)
1535+
check(assert_type(df.to_numpy(na_value=0), np.ndarray), np.ndarray)
1536+
check(assert_type(df.to_numpy(na_value=np.int32(4)), np.ndarray), np.ndarray)
1537+
check(assert_type(df.to_numpy(na_value=np.float16(3.68)), np.ndarray), np.ndarray)
1538+
check(
1539+
assert_type(df.to_numpy(na_value=np.complex128(3.8, -493.2)), np.ndarray),
1540+
np.ndarray,
1541+
)
1542+
15341543

15351544
def test_to_markdown() -> None:
15361545
df = pd.DataFrame(data={"col1": [1, 1, 2], "col2": [3, 4, 5]})

‎tests/test_series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,9 @@ def test_types_to_numpy() -> None:
17491749
check(assert_type(s.to_numpy(), np.ndarray), np.ndarray)
17501750
check(assert_type(s.to_numpy(dtype="str", copy=True), np.ndarray), np.ndarray)
17511751
check(assert_type(s.to_numpy(na_value=0), np.ndarray), np.ndarray)
1752+
check(assert_type(s.to_numpy(na_value=np.int32(4)), np.ndarray), np.ndarray)
1753+
check(assert_type(s.to_numpy(na_value=np.float16(4)), np.ndarray), np.ndarray)
1754+
check(assert_type(s.to_numpy(na_value=np.complex128(4, 7)), np.ndarray), np.ndarray)
17521755

17531756

17541757
def test_where() -> None:

0 commit comments

Comments
 (0)