Skip to main content
1 of 5
Alexxino
  • 928
  • 1
  • 4
  • 22

Trouble handling CiphertextForRecipient after KMS GenerateDataKey

I'm building a route for my go (gin) server that generates a Data Encryption Key (DEK), following the principle of zero-trust (the backend never sees the plaintext.

Right now, the client generates a pair of keys and shares the public one to the server (that also acts as a proxy between the client and the enclave).

The enclave also generates a pair of keys and then attach the public one to the attestation. However when I try to generate a DEK, the lengths of CiphertextBlob and CiphertextForRecipient are:

Generated DEK 0: CiphertextBlob len=184, CiphertextForRecipient len=493

They are both []byte, so I think they are already decoded from B64. But shouldn't CiphertextForRecipient be of length 256 to be decrypted using the generated key?

The base implementation of the enclave is as follows (using github.com/hf/nsm):

package main

// ...

var (
    nsmSession        *nsm.Session
    nsmPrivateKey     *rsa.PrivateKey
    nsmPublicKeyBytes []byte
    attestation utils.Attestation = utils.Attestation{Mu: sync.RWMutex{}}
)

func main() {
    log.Println("Starting enclave service...")
    if err := initiateNSM(); err != nil {
        log.Fatalf("Failed to initiate NSM session: %v", err)
    }
    port := getEnvUint32("ENCLAVE_PORT", 5000)
    listener, err := vsock.Listen(port, nil)
    if err != nil {
        log.Fatalf("Failed to start listener: %v", err)
    }
    defer listener.Close()

    // ...

    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Printf("Failed to accept connection: %v", err)
            continue
        }

        log.Printf("New connection accepted from %s", conn.RemoteAddr())
        go handleConnection(conn)
    }
}

func initiateNSM() (err error) {
    // open NSM session
    nsmSession, err = nsm.OpenDefaultSession()
    if err != nil {
        return fmt.Errorf("failed to open NSM session: %v", err)
    }

    // generate RSA key pair for NSM session, nsmSession is used here as a crypto/rand.Reader
    nsmPrivateKey, err = rsa.GenerateKey(nsmSession, 2048)
    if err != nil {
        return fmt.Errorf("failed to generate RSA key: %v", err)
    }

    // extract public key in DER format
    nsmPublicKeyBytes, err = x509.MarshalPKIXPublicKey(&nsmPrivateKey.PublicKey)
    if err != nil {
        return fmt.Errorf("failed to marshal RSA public key: %v", err)
    }

    log.Printf(
        "RSA public key fingerprint: %x",
        sha256.Sum256(nsmPublicKeyBytes),
    )

    log.Println("NSM session initiated")

    return refreshAttestation()
}

func refreshAttestation() error {
    att, err := utils.RequestAttestation(nsmSession, nsmPublicKeyBytes)
    if err != nil {
        return err
    }

    attestation.Mu.Lock()
    attestation.Doc = append([]byte(nil), att...)
    attestation.Mu.Unlock()

    log.Printf("Attestation document refreshed: len=%d", len(attestation.Doc))
    return nil
}

func startAttestationRefresher(ctx context.Context) {
    ticker := time.NewTicker(4 * time.Minute)
    go func() {
        defer ticker.Stop()
        for {
            select {
            case <-ticker.C:
                if err := refreshAttestation(); err != nil {
                    log.Printf("Failed to refresh attestation document: %v", err)
                }
            case <-ctx.Done():
                return
            }
        }
    }()
}
package utils

type Attestation struct {
    Mu  sync.RWMutex
    Doc []byte
}

func RequestAttestation(nsmSession *nsm.Session, nsmPublicKeyBytes []byte) ([]byte, error) {
    req := request.Attestation{
        PublicKey: nsmPublicKeyBytes,
        Nonce:     nil,
        UserData:  nil,
    }
    res, err := nsmSession.Send(&req)
    if err != nil {
        return nil, fmt.Errorf("failed to get attestation document: %v", err)
    }
    if res.Attestation == nil || len(res.Attestation.Document) == 0 {
        return nil, fmt.Errorf("received empty attestation document from NSM")
    }
    return res.Attestation.Document, nil
}

This is the server/proxy route:

func (p *VsockProxy) HandleSessionGenerateDEK(c *gin.Context) {
    // Validate request (SessionGenerateDEKRequest)
    var generateDEKReq enclaveproto.SessionGenerateDEKRequest
    if err := c.BindJSON(&generateDEKReq); err != nil {
        c.JSON(400, gin.H{"error": fmt.Sprintf("Invalid session unwrap request: %v", err)})
        return
    }
    if generateDEKReq.Count <= 0 || generateDEKReq.Count > 100 {
        c.JSON(400, gin.H{"error": "Count must be between 1 and 100"})
        return
    }
    payloadBytes, err := json.Marshal(generateDEKReq)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
        return
    }

    // Request attestation document from enclave (GetAttestationRequest)
    attReq := enclaveproto.Request{
        Type: "get_attestation",
    }
    resAttBytes, err := p.sendToEnclave(attReq, 5*time.Second)
    if err != nil {
        c.JSON(500, gin.H{"error": err.Error()})
        return
    }
    var resAtt enclaveproto.Response[enclaveproto.GetAttestationResponse]
    if err := json.Unmarshal(resAttBytes, &resAtt); err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal get attestation response: %v", err)})
        return
    }
    att, err := base64.StdEncoding.DecodeString(resAtt.Data.Attestation)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to decode attestation document: %v", err)})
        return
    }
    if !resAtt.Success {
        c.JSON(500, gin.H{"error": resAtt.Error})
        return
    }
    if len(att) == 0 {
        c.JSON(500, gin.H{"error": "Empty attestation document"})
        return
    }

    // Generate data keys with KMS using attestation document
    input := &kms.GenerateDataKeyInput{
        KeyId:   &generateDEKReq.KeyId,
        KeySpec: "AES_256",
        Recipient: &kmsTypes.RecipientInfo{
            AttestationDocument:    att,
            KeyEncryptionAlgorithm: kmsTypes.KeyEncryptionMechanismRsaesOaepSha256,
        },
    }
    deks := make([]enclaveproto.EnclaveDEKToPrepare, generateDEKReq.Count)
    for i := 0; i < generateDEKReq.Count; i++ {
        out, err := p.KMSClient.GenerateDataKey(c.Request.Context(), input)
        if err != nil {
            c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to generate data key: %v", err)})
            return
        }
        if out.CiphertextForRecipient == nil {
            c.JSON(500, gin.H{"error": "No CiphertextForRecipient in KMS response"})
            return
        }
        deks[i] = enclaveproto.EnclaveDEKToPrepare{
            CiphertextBlob:         base64.StdEncoding.EncodeToString(out.CiphertextBlob),
            CiphertextForRecipient: base64.StdEncoding.EncodeToString(out.CiphertextForRecipient),
        }
        log.Printf("Generated DEK %d: CiphertextBlob len=%d, CiphertextForRecipient len=%d", i, len(out.CiphertextBlob), len(out.CiphertextForRecipient))
    }

    // Prepare DEK in enclave (EnclavePrepareDEKRequest)
    prepareDEKReq := enclaveproto.EnclavePrepareDEKRequest{
        DEKs:      deks,
        SessionId: generateDEKReq.SessionId,
    }
    payloadBytes, err = json.Marshal(prepareDEKReq)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
        return
    }
    req := enclaveproto.Request{
        Type:    "session_prepare_dek",
        Payload: json.RawMessage(payloadBytes),
    }
    resBytes, err := p.sendToEnclave(req, 30*time.Second)
    if err != nil {
        c.JSON(500, gin.H{"error": err.Error()})
        return
    }
    var prepRes enclaveproto.Response[enclaveproto.EnclavePrepareDEKResponse]
    if err := json.Unmarshal(resBytes, &prepRes); err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal session unwrap response: %v", err)})
        return
    }
    if !prepRes.Success {
        c.JSON(500, gin.H{"error": prepRes.Error})
        return
    }

    // Build response
    res := enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
        Success: true,
        Data: enclaveproto.SessionGenerateDEKResponse{
            DEKs: prepRes.Data.DEKs,
        },
    }

    c.JSON(200, res)
}

And this is the enclave handler that is called:

func HandleSessionPrepareDEK(encoder *json.Encoder, payload json.RawMessage, nsmSession *nsm.Session, nsmPrivKey *rsa.PrivateKey) {
    var req enclaveproto.EnclavePrepareDEKRequest
    if err := json.Unmarshal(payload, &req); err != nil {
        utils.SendError(encoder, fmt.Sprintf("Invalid session generate DEK request: %v", err))
        return
    }

    sess, ok := session.GetSession(req.SessionId)
    if !ok {
        utils.SendError(encoder, "Invalid or expired session ID")
        return
    }

    aead, err := chacha20poly1305.NewX(sess.Key)
    if err != nil {
        utils.SendError(encoder, fmt.Sprintf("Failed to create AEAD cipher: %v", err))
        return
    }

    results := make([]enclaveproto.GeneratedDEK, 0, len(req.DEKs))

    for _, dek := range req.DEKs {
        ciphertextForRecipient, err := base64.StdEncoding.DecodeString(dek.CiphertextForRecipient)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextForRecipient: %v", err))
            return
        }
        ciphertextBlob, err := base64.StdEncoding.DecodeString(dek.CiphertextBlob)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextBlob: %v", err))
            return
        }
        log.Printf(
            "CiphertextForRecipient len=%d CiphertextBlob len=%d",
            len(ciphertextForRecipient),
            len(ciphertextBlob),
        )
        if len(ciphertextForRecipient) != 256 { // <--- THIS THROWS HERE!!!
            log.Printf("INVALID RSA CIPHERTEXT LENGTH: %d", len(ciphertextForRecipient))
            utils.SendError(encoder, "invalid CiphertextForRecipient length")
            return
        }
        plainDEK, err := rsa.DecryptOAEP(sha256.New(), nsmSession, nsmPrivKey, ciphertextForRecipient, nil)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decrypt data key: %v", err))
            return
        }

        // nonce for session encryption
        nonce := make([]byte, chacha20poly1305.NonceSizeX) // 24 bytes
        if _, err := nsmSession.Read(nonce); err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to read nonce: %v", err))
            return
        }

        // seal DEK with session key
        sealedDEK := aead.Seal(nil, nonce, plainDEK, nil)

        utils.Zero(plainDEK)

        results = append(results, enclaveproto.GeneratedDEK{
            SealedDEK:       base64.StdEncoding.EncodeToString(sealedDEK),
            KmsEncryptedDEK: dek.CiphertextBlob,
            Nonce:           base64.StdEncoding.EncodeToString(nonce),
        })
    }

    res := enclaveproto.SessionGenerateDEKResponse{
        DEKs: results,
    }

    utils.SendResponse(encoder, enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
        Success: true,
        Data:    res,
    })
}

Thanks in advance for any help.

Alexxino
  • 928
  • 1
  • 4
  • 22