Skip to content
2 changes: 2 additions & 0 deletions giskard/scanner/robustness/text_perturbation_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TextPerturbationDetector(BaseTextPerturbationDetector):

def _get_default_transformations(self, model: BaseModel, dataset: Dataset) -> Sequence[TextTransformation]:
from .text_transformations import (
TextAccentRemovalTransformation,
TextLowercase,
TextPunctuationRemovalTransformation,
TextTitleCase,
Expand All @@ -38,4 +39,5 @@ def _get_default_transformations(self, model: BaseModel, dataset: Dataset) -> Se
TextTitleCase,
TextTypoTransformation,
TextPunctuationRemovalTransformation,
TextAccentRemovalTransformation,
]
17 changes: 17 additions & 0 deletions giskard/scanner/robustness/text_transformations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import json
import re
import unicodedata
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -193,6 +194,22 @@ def make_perturbation(self, text):
return "".join(pieces)


class TextAccentRemovalTransformation(TextTransformation):
    """Perturb text by stripping diacritics (accents) from its characters.

    The text is canonically decomposed (NFD) so that an accented letter becomes
    a base character followed by one or more combining marks (Unicode category
    ``Mn``). Each combining mark is then dropped with probability ``rate`` and
    the remaining characters are recomposed (NFC).

    Parameters
    ----------
    column
        Forwarded unchanged to the base class constructor.
    rate
        Probability of removing each individual combining mark. With the
        default of ``1.0`` every mark is removed.
    rng_seed
        Seed for the NumPy random generator that decides which marks to drop.
    """

    name = "Accent Removal"

    def __init__(self, column, rate=1.0, rng_seed=1729):
        super().__init__(column)
        self.rate = rate
        self.rng = np.random.default_rng(seed=rng_seed)

    def make_perturbation(self, text):
        """Return ``text`` with combining marks removed according to ``rate``."""
        decomposed = unicodedata.normalize("NFD", text)
        stripped = "".join(
            char
            for char in decomposed
            # The RNG is only consulted for combining marks (short-circuit), so
            # the random stream does not depend on the unaccented characters.
            if unicodedata.category(char) != "Mn" or self.rng.random() > self.rate
        )
        # Recompose: NFD also splits things that are not accents (e.g. Hangul
        # syllables into jamo), and any accent that was kept should come back as
        # its precomposed code point rather than as base + combining mark.
        # Without this step, text that has no accents at all could still be
        # returned as a different code point sequence than the input.
        return unicodedata.normalize("NFC", stripped)


class TextLanguageBasedTransformation(TextTransformation):
needs_dataset = True

Expand Down
27 changes: 27 additions & 0 deletions tests/scan/test_text_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,33 @@ def test_punctuation_strip_transformation():
assert transformed_text[5] == "comma separated list"


def test_accent_removal_transformation():
    """Accent removal strips diacritics and leaves unaccented text untouched."""
    from giskard.scanner.robustness.text_transformations import TextAccentRemovalTransformation

    # (input text, expected text after accent removal)
    cases = [
        ("C'est l'été", "C'est l'ete"),
        ("çà et là", "ca et la"),
        ("Tiếng Việt", "Tieng Viet"),
        ("État", "Etat"),
        ("你好", "你好"),
    ]

    dataset = _dataset_from_dict({"text": [source for source, _ in cases]})
    transformation = TextAccentRemovalTransformation(column="text")

    result_df = dataset.transform(transformation).df
    stripped_texts = result_df.text.values

    for position, (_, expected) in enumerate(cases):
        assert stripped_texts[position] == expected


def test_religion_based_transformation():
dataset = _dataset_from_dict(
{
Expand Down