Skip to content

Commit f1ce938

Browse files
authored
Test rm (#373)
* test: add integration tests for model remove * test: skip digest-based references in model removal tests * refactor: update removeModel function to include force parameter * improve integration tests for model removal functionality * fix: enhance model ID expansion logic and update cleanup behavior in tests * fix: handle digest references in model normalization and deletion logic
1 parent afd9517 commit f1ce938

File tree

3 files changed

+270
-9
lines changed

3 files changed

+270
-9
lines changed

‎cmd/cli/commands/integration_test.go‎

Lines changed: 244 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ func dockerModelRunner(t *testing.T, ctx context.Context, net *testcontainers.Do
185185
}
186186

187187
// removeModel removes a model from the local store
188-
func removeModel(client *desktop.Client, modelID string) error {
189-
_, err := client.Remove([]string{modelID}, true)
188+
func removeModel(client *desktop.Client, modelID string, force bool) error {
189+
_, err := client.Remove([]string{modelID}, force)
190190
return err
191191
}
192192

@@ -400,7 +400,7 @@ func TestIntegration_PullModel(t *testing.T) {
400400

401401
// Clean up: remove the model for the next test iteration
402402
t.Logf("Removing model %s", truncatedID)
403-
err = removeModel(env.client, tc.expectedModelID)
403+
err = removeModel(env.client, tc.expectedModelID, true)
404404
require.NoError(t, err, "Failed to remove model")
405405
})
406406
}
@@ -456,7 +456,7 @@ func TestIntegration_InspectModel(t *testing.T) {
456456

457457
// Cleanup: remove the model
458458
t.Logf("Removing model %s", truncatedID)
459-
err = removeModel(env.client, modelID)
459+
err = removeModel(env.client, modelID, true)
460460
require.NoError(t, err, "Failed to remove model")
461461

462462
// Verify model was removed
@@ -621,7 +621,7 @@ func TestIntegration_TagModel(t *testing.T) {
621621

622622
// Cleanup: remove the model
623623
t.Logf("Removing model %s", truncatedID)
624-
err = removeModel(env.client, modelID)
624+
err = removeModel(env.client, modelID, true)
625625
require.NoError(t, err, "Failed to remove model")
626626

627627
// Verify model was removed
@@ -768,11 +768,249 @@ func TestIntegration_PushModel(t *testing.T) {
768768

769769
// Final cleanup: remove the model
770770
t.Logf("Removing model %s", truncatedID)
771-
err = removeModel(env.client, modelID)
771+
err = removeModel(env.client, modelID, true)
772772
require.NoError(t, err, "Failed to remove model")
773773

774774
// Verify model was removed
775775
models, err = listModels(false, env.client, true, false, "")
776776
require.NoError(t, err)
777777
require.Empty(t, strings.TrimSpace(models), "Model should be removed")
778778
}
779+
780+
// TestIntegration_RemoveModel tests removing models with various reference formats
781+
// to ensure proper reference normalization and correct removal behavior.
782+
func TestIntegration_RemoveModel(t *testing.T) {
783+
env := setupTestEnv(t)
784+
785+
// Ensure no models exist initially
786+
models, err := listModels(false, env.client, true, false, "")
787+
require.NoError(t, err)
788+
if len(models) != 0 {
789+
t.Fatal("Expected no initial models, but found some")
790+
}
791+
792+
// Create and push a test model with default org (ai/rm-test:latest)
793+
modelRef := "ai/rm-test:latest"
794+
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
795+
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)
796+
797+
// Generate all reference test cases
798+
info := modelInfo{
799+
name: "rm-test",
800+
org: "ai",
801+
tag: "latest",
802+
registry: "registry.local:5000",
803+
modelID: modelID,
804+
digest: digest,
805+
expectedName: "ai/rm-test:latest",
806+
}
807+
testCases := generateReferenceTestCases(info)
808+
809+
// Remove model using various reference formats
810+
t.Run("remove with various reference formats", func(t *testing.T) {
811+
for _, tc := range testCases {
812+
813+
t.Run(tc.name, func(t *testing.T) {
814+
// Pull the model
815+
pullRef := "rm-test"
816+
t.Logf("Pulling model with reference: %s", pullRef)
817+
err := pullModel(newPullCmd(), env.client, pullRef, true)
818+
require.NoError(t, err, "Failed to pull model")
819+
820+
// Verify model exists
821+
models, err := listModels(false, env.client, true, false, "")
822+
require.NoError(t, err)
823+
truncatedID := modelID[7:19]
824+
require.Equal(t, truncatedID, strings.TrimSpace(models), "Model not found after pull")
825+
826+
// Remove using the test case reference
827+
t.Logf("Removing model with reference: %s", tc.ref)
828+
err = removeModel(env.client, tc.ref, false)
829+
require.NoError(t, err, "Failed to remove model with reference: %s", tc.ref)
830+
831+
// Verify model is removed
832+
models, err = listModels(false, env.client, true, false, "")
833+
require.NoError(t, err)
834+
require.Empty(t, strings.TrimSpace(models), "Model should be removed after rm with reference: %s", tc.ref)
835+
836+
t.Logf("✓ Successfully removed model using reference: %s", tc.ref)
837+
})
838+
}
839+
})
840+
841+
// Remove multiple models in one command
842+
t.Run("remove multiple models", func(t *testing.T) {
843+
// Create and push two different models
844+
modelRef1 := "ai/rm-multi-1:latest"
845+
modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
846+
modelRef2 := "ai/rm-multi-2:latest"
847+
modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
848+
849+
// Pull both models
850+
t.Logf("Pulling first model: rm-multi-1")
851+
err := pullModel(newPullCmd(), env.client, "rm-multi-1", true)
852+
require.NoError(t, err, "Failed to pull first model")
853+
854+
t.Logf("Pulling second model: rm-multi-2")
855+
err = pullModel(newPullCmd(), env.client, "rm-multi-2", true)
856+
require.NoError(t, err, "Failed to pull second model")
857+
858+
// Verify both models exist
859+
models, err := listModels(false, env.client, false, false, "")
860+
require.NoError(t, err)
861+
require.Contains(t, models, modelID1[7:19], "First model should exist")
862+
require.Contains(t, models, modelID2[7:19], "Second model should exist")
863+
864+
// Remove both models in one command
865+
t.Logf("Removing both models: rm-multi-1 and rm-multi-2")
866+
_, err = env.client.Remove([]string{"rm-multi-1", "rm-multi-2"}, false)
867+
require.NoError(t, err, "Failed to remove multiple models")
868+
869+
// Verify both models are removed
870+
models, err = listModels(false, env.client, true, false, "")
871+
require.NoError(t, err)
872+
require.Empty(t, strings.TrimSpace(models), "All models should be removed")
873+
874+
t.Logf("✓ Successfully removed multiple models in one command")
875+
})
876+
877+
// Tag-specific removal (removing one tag keeps others)
878+
t.Run("remove specific tag keeps other tags", func(t *testing.T) {
879+
// Pull the model
880+
t.Logf("Pulling model: rm-test")
881+
err := pullModel(newPullCmd(), env.client, "rm-test", true)
882+
require.NoError(t, err, "Failed to pull model")
883+
884+
// Add multiple tags to the same model
885+
t.Logf("Adding tags v1, v2, and v3 to the model")
886+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:v1")
887+
require.NoError(t, err, "Failed to create v1 tag")
888+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:v2")
889+
require.NoError(t, err, "Failed to create v2 tag")
890+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:v3")
891+
require.NoError(t, err, "Failed to create v3 tag")
892+
893+
// Verify all tags exist
894+
model, err := env.client.Inspect(modelID, false)
895+
require.NoError(t, err)
896+
require.GreaterOrEqual(t, len(model.Tags), 4, "Model should have at least 4 tags")
897+
t.Logf("Model has %d tags: %v", len(model.Tags), model.Tags)
898+
899+
// Remove one specific tag (v1)
900+
t.Logf("Removing only the v1 tag")
901+
err = removeModel(env.client, "rm-test:v1", false)
902+
require.NoError(t, err, "Failed to remove v1 tag")
903+
904+
// Verify the model still exists with remaining tags
905+
model, err = env.client.Inspect(modelID, false)
906+
require.NoError(t, err, "Model should still exist after removing one tag")
907+
908+
// Verify v1 tag is gone but others remain
909+
tagFound := false
910+
for _, tag := range model.Tags {
911+
if tag == "rm-test:v1" || tag == "ai/rm-test:v1" {
912+
tagFound = true
913+
break
914+
}
915+
}
916+
require.False(t, tagFound, "v1 tag should be removed")
917+
918+
// Verify other tags still exist
919+
v2Found := false
920+
v3Found := false
921+
for _, tag := range model.Tags {
922+
if strings.Contains(tag, ":v2") {
923+
v2Found = true
924+
}
925+
if strings.Contains(tag, ":v3") {
926+
v3Found = true
927+
}
928+
}
929+
require.True(t, v2Found, "v2 tag should still exist")
930+
require.True(t, v3Found, "v3 tag should still exist")
931+
932+
t.Logf("✓ Successfully removed specific tag while keeping others")
933+
934+
// Cleanup: remove the entire model (force=true since multiple tags remain)
935+
err = removeModel(env.client, modelID, true)
936+
require.NoError(t, err, "Failed to cleanup model")
937+
})
938+
939+
// Model ID removal removes all tags
940+
t.Run("remove by model ID removes all tags", func(t *testing.T) {
941+
// Pull the model
942+
t.Logf("Pulling model: rm-test")
943+
err := pullModel(newPullCmd(), env.client, "rm-test", true)
944+
require.NoError(t, err, "Failed to pull model")
945+
946+
// Add multiple tags
947+
t.Logf("Adding multiple tags to the model")
948+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:tag1")
949+
require.NoError(t, err, "Failed to create tag1")
950+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:tag2")
951+
require.NoError(t, err, "Failed to create tag2")
952+
err = tagModel(newTagCmd(), env.client, "rm-test", "rm-test:tag3")
953+
require.NoError(t, err, "Failed to create tag3")
954+
955+
// Verify tags exist
956+
model, err := env.client.Inspect(modelID, false)
957+
require.NoError(t, err)
958+
require.GreaterOrEqual(t, len(model.Tags), 4, "Model should have multiple tags")
959+
t.Logf("Model has %d tags before ID removal", len(model.Tags))
960+
961+
// Remove by model ID (should remove entire model and all tags)
962+
t.Logf("Removing by model ID: %s", modelID)
963+
err = removeModel(env.client, modelID, false)
964+
if !strings.Contains(err.Error(), "(must be forced) due to multiple tag references") {
965+
t.Fatalf("Expected error about multiple tag references when removing by ID without force, got: %v", err)
966+
}
967+
require.Error(t, err, "(must be forced) due to multiple tag references")
968+
})
969+
970+
// Force flag behavior
971+
t.Run("force flag", func(t *testing.T) {
972+
// Pull the model
973+
t.Logf("Pulling model: rm-test")
974+
err := pullModel(newPullCmd(), env.client, "rm-test", true)
975+
require.NoError(t, err, "Failed to pull model")
976+
977+
// Test removal with force flag
978+
t.Logf("Removing model with force flag")
979+
err = removeModel(env.client, modelID, true)
980+
require.NoError(t, err, "Failed to remove with force flag")
981+
982+
// Verify model is removed
983+
models, err := listModels(false, env.client, true, false, "")
984+
require.NoError(t, err)
985+
require.Empty(t, strings.TrimSpace(models), "Model should be removed with force flag")
986+
987+
t.Logf("✓ Successfully removed model with force flag")
988+
})
989+
990+
// Error cases
991+
t.Run("error cases", func(t *testing.T) {
992+
t.Run("remove non-existent model", func(t *testing.T) {
993+
err := removeModel(env.client, "non-existent-model:v1", false)
994+
require.Error(t, err, "Should fail when removing non-existent model")
995+
t.Logf("✓ Correctly failed to remove non-existent model: %v", err)
996+
})
997+
998+
t.Run("force remove non-existent model", func(t *testing.T) {
999+
err := removeModel(env.client, "non-existent-model:v1", true)
1000+
require.Error(t, err, "Should fail when removing non-existent model")
1001+
t.Logf("✓ Correctly failed to remove non-existent model: %v", err)
1002+
})
1003+
1004+
t.Run("remove with empty reference", func(t *testing.T) {
1005+
_, err := env.client.Remove([]string{""}, false)
1006+
require.Error(t, err, "Should fail with empty reference")
1007+
t.Logf("✓ Correctly failed to remove with empty reference: %v", err)
1008+
})
1009+
1010+
t.Run("remove with invalid reference", func(t *testing.T) {
1011+
err := removeModel(env.client, "invalid:reference:format", false)
1012+
require.Error(t, err, "Should fail with invalid reference format")
1013+
t.Logf("✓ Correctly failed to remove with invalid reference: %v", err)
1014+
})
1015+
})
1016+
}

‎cmd/cli/desktop/desktop.go‎

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,26 @@ func (c *Client) Remove(modelArgs []string, force bool) (string, error) {
489489
modelRemoved := ""
490490
for _, model := range modelArgs {
491491
model = normalizeHuggingFaceModelName(model)
492-
// Check if not a model ID passed as parameter.
493-
if !strings.Contains(model, "/") {
492+
493+
// Handle digest references (model@sha256:...)
494+
// These need to be normalized to include default org if missing
495+
if strings.Contains(model, "@") && !strings.Contains(model, "/") {
496+
// Split on @ to get repository and digest
497+
parts := strings.SplitN(model, "@", 2)
498+
if len(parts) == 2 {
499+
repo := parts[0]
500+
digest := parts[1]
501+
// Add default org if the repository doesn't contain a slash
502+
if !strings.Contains(repo, "/") {
503+
model = fmt.Sprintf("ai/%s@%s", repo, digest)
504+
}
505+
}
506+
}
507+
508+
// Only expand simple names without tags or digests to model IDs
509+
// Tagged references (model:tag) and digest references (model@sha256:...)
510+
// should be passed as-is to allow tag-specific operations
511+
if !strings.Contains(model, "/") && !strings.Contains(model, ":") && !strings.Contains(model, "@") {
494512
if expanded, err := c.fullModelID(model); err == nil {
495513
model = expanded
496514
}

‎pkg/distribution/distribution/client.go‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"slices"
10+
"strings"
1011

1112
"github.com/docker/model-runner/pkg/internal/utils"
1213
"github.com/sirupsen/logrus"
@@ -309,7 +310,11 @@ func (c *Client) DeleteModel(reference string, force bool) (*DeleteModelResponse
309310
if err != nil {
310311
return &DeleteModelResponse{}, fmt.Errorf("getting model ID: %w", err)
311312
}
312-
isTag := id != reference
313+
314+
// Check if this is a digest reference (contains @)
315+
// Digest references like "name@sha256:..." should be treated as ID references, not tags
316+
isDigestReference := strings.Contains(reference, "@")
317+
isTag := id != reference && !isDigestReference
313318

314319
resp := DeleteModelResponse{}
315320

0 commit comments

Comments
 (0)