Conversation

Contributor

@volker48 volker48 commented Jun 4, 2025

I wanted to be able to pass along a weight Tensor to better handle working with imbalanced class data.

Contributor

@stephantul stephantul left a comment

Hey, thanks for your contribution. You don't actually pass the weight to the LightningModule, so it never gets used (see line 209).

Changes I'd like to see added:

  1. A check: the weight needs to have length C, the number of unique classes. If len(weight) is not C, this should throw a ValueError. If you do this after initialize, the unique labels should already have been stored, for both the single-label and the multilabel case.
  2. A test: check the above condition. Maybe also check that the weight correctly propagates to the module, but there's no need to run any training.
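A minimal sketch of the requested check, assuming the unique labels are available as a sequence once initialization has run. The function name `validate_class_weight` and its `classes` argument are hypothetical and are not taken from the PR; the weight can be anything that supports `len()`, such as a list or a 1-D tensor:

```python
from typing import Optional, Sequence


def validate_class_weight(class_weight: Optional[Sequence[float]], classes: Sequence[str]) -> None:
    """Raise ValueError unless class_weight has exactly one entry per class.

    Hypothetical helper: `classes` stands in for wherever the unique labels
    are stored after initialization.
    """
    if class_weight is None:
        # No weighting requested, so there is nothing to validate.
        return
    n_classes = len(classes)
    if len(class_weight) != n_classes:
        raise ValueError(
            f"class_weight has length {len(class_weight)}, "
            f"but there are {n_classes} unique classes."
        )
```

A test for this only needs to call the check with a weight of the wrong length and assert that a ValueError is raised; no training run is required.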

Contributor Author

volker48 commented Jun 5, 2025

Good catch 🤦‍♂️. A sanity check on the weight dimensions is a solid idea as well. I'll update.

Contributor Author

volker48 commented Jun 5, 2025

@stephantul ok I made a few updates here:

  1. renamed weight -> class_weight for clarity
  2. added a check to make sure the weight has length equal to number of classes
  3. added a test to make sure this error is thrown if class_weight has an incorrect length
  4. had to make a small change to the state dict processing code to remove loss_function.weight; otherwise there was an error, since the loss_function is not part of the module: https://github.com/MinishLab/model2vec/pull/260/files#diff-42ac83d7a74d22cbee35cd49306fc47bd07f0352eb0b67c994d663c08d5af9ecR252. Discovered this issue in the test, so 👍 to testing 🤣.
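
For context on item 4: weighted PyTorch losses such as `torch.nn.CrossEntropyLoss` store `weight` as a buffer, so it appears in the state dict of any module that holds the loss as an attribute. A rough sketch of the kind of filtering described, using a plain dict and a hypothetical function name (the actual change is in the linked diff and may differ):

```python
from typing import Any, Dict


def strip_loss_function_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict without entries owned by the loss function.

    Hypothetical helper: assumes the loss is stored under the attribute name
    `loss_function`, so its buffers are keyed as `loss_function.<name>`.
    """
    prefix = "loss_function."
    return {key: value for key, value in state_dict.items() if not key.startswith(prefix)}


# Example with placeholder values instead of real tensors.
filtered = strip_loss_function_keys({"head.weight": 1, "loss_function.weight": 2})
```

Here `filtered` keeps `head.weight` and drops `loss_function.weight`, so the remaining keys match a module that has no loss attached.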
Contributor

@stephantul stephantul left a comment

Thanks, this looks really nice!

@Pringled Pringled merged commit 4867cb8 into MinishLab:main Jun 6, 2025

3 participants