[python-package] Dataset sampling from Sequences should use ranges, rather than single row indices, for performance #7006

@nmusolino

Summary

Dataset objects can be constructed from a list of lgb.Sequence objects. When the Dataset samples the data in the list of Sequence objects, it accesses one row at a time by passing a single integer index to the Sequence object (at basic.py:2218).

The Sequence base class also supports indexing by slice or by List[int]:

    @abc.abstractmethod
    def __getitem__(self, idx: Union[int, slice, List[int]]) -> np.ndarray:
        """Return data for given row index.  [...]"""

Dataset sampling should perform indexing with slice objects or lists of int, under the assumption that user-defined Sequence classes can more efficiently return rows this way.
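
For illustration, here is a minimal sketch of a user-defined Sequence whose __getitem__ serves an int, a slice, or a list of ints in a single call. The NumpySequence class and its in-memory backing array are hypothetical; a realistic implementation would wrap an on-disk or remote store, where one batched read is much cheaper than many single-row lookups.

import numpy as np
import lightgbm as lgb

class NumpySequence(lgb.Sequence):
    """Toy Sequence backed by an in-memory array (hypothetical example)."""

    def __init__(self, data: np.ndarray, batch_size: int = 4096):
        self.data = data
        self.batch_size = batch_size

    def __getitem__(self, idx):
        # NumPy indexing accepts an int, a slice, or a list of ints,
        # so one call returns either a single row or a 2-D block of rows.
        return self.data[idx]

    def __len__(self) -> int:
        return len(self.data)

With such a class, seq[[3, 7, 42]] or seq[0:1000] returns a block of rows in one call, which is exactly the access pattern this proposal would let the sampling code exploit.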

Motivation

In #4089, LightGBM added a Sequence base class, so that users can develop custom data accessors. A user-defined Sequence class should implement random access by row index, or range access by slice or list[int].

When the Dataset class is constructed with a Sequence object or a list thereof, it samples the data, selecting bin_construct_sample_cnt row indices and accessing the rows through the Sequence API.

For some Sequence classes, this row-by-row access may be less efficient than slice- or list[int]-based access. In my experience, it was the rate-limiting step in constructing Dataset objects. The Python sampling methods should use range-based access for better performance.

Description

When the user passes a list of Sequence objects, the Dataset class should obtain a list of indices through _create_sample_indices(), as it does now. It should then group these indices (e.g. using itertools.takewhile()) into the indices belonging to each Sequence object, with at most batch_size indices per group. After being translated into sequence-local indices, each group can be passed to its Sequence object in a single call (see the sketch below).
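
A minimal sketch of the grouping and batched lookup described above. The helper names _group_sample_indices and _yield_sampled_blocks are hypothetical and are not existing LightGBM internals:

import numpy as np
from lightgbm import Sequence

def _group_sample_indices(seqs, indices):
    """Hypothetical helper: split sorted global sample indices into one
    array of sequence-local indices per Sequence in seqs."""
    indices = np.sort(np.asarray(indices))
    groups, offset = [], 0
    for seq in seqs:
        end = offset + len(seq)
        in_seq = indices[(indices >= offset) & (indices < end)]
        groups.append(in_seq - offset)  # translate to sequence-local indices
        offset = end
    return groups

def _yield_sampled_blocks(seqs, indices):
    """Hypothetical generator: fetch each per-sequence group in chunks of at
    most batch_size indices, using list-of-int access instead of row-by-row."""
    for seq, local_idx in zip(seqs, _group_sample_indices(seqs, indices)):
        batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
        for start in range(0, len(local_idx), batch_size):
            chunk = local_idx[start:start + batch_size].tolist()
            # One __getitem__ call returns a 2-D block of sampled rows.
            yield np.atleast_2d(seq[chunk])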

As a special case, when the sample count (i.e. the number of indices returned by _create_sample_indices()) is equal to the total row count, we could access all the rows of each Sequence with slice objects. For reference, the _init_from_seqs() method currently does this when reading through the data after sampling:

batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
for start in range(0, nrow, batch_size):
    end = min(start + batch_size, nrow)
    self._push_rows(seq[start:end])
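
For this special case, the sampling path could reuse the same batched-slice pattern; a sketch with a hypothetical helper name:

import numpy as np
from lightgbm import Sequence

def _yield_all_blocks(seqs):
    """Hypothetical special case: when the sample count equals the total row
    count, read each Sequence in full with slice access."""
    for seq in seqs:
        nrow = len(seq)
        batch_size = getattr(seq, "batch_size", None) or Sequence.batch_size
        for start in range(0, nrow, batch_size):
            end = min(start + batch_size, nrow)
            # Each slice returns a 2-D block of rows in a single call.
            yield np.atleast_2d(seq[start:end])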

After this change, we may have to tweak the concatenation of sampled rows:

# Select sampled rows, transpose to column order.
sampled = np.array(list(self._yield_row_from_seqlist(seqs, indices)))
sampled = sampled.T

to handle two-dimensional results.
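
With block-wise access, the sampled array could be assembled by concatenating the 2-D blocks before transposing, rather than stacking single rows. A sketch, assuming a generator like the hypothetical _yield_sampled_blocks above:

# Select sampled rows as 2-D blocks, concatenate, then transpose to column order.
sampled = np.concatenate(list(_yield_sampled_blocks(seqs, indices)), axis=0)
sampled = sampled.T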
