1

I'm building a route for my go (gin) server that generates a Data Encryption Key (DEK), following the principle of zero-trust (the backend never sees the plaintext).

Right now, the client generates a pair of keys and shares the public one to the server (that also acts as a proxy between the client and the enclave).

The enclave also generates a pair of keys and then attach the public one to the attestation. However when I try to generate a DEK, the lengths of CiphertextBlob and CiphertextForRecipient are:

Generated DEK 0: CiphertextBlob len=184, CiphertextForRecipient len=493

They are both []byte, so I think they are already decoded from B64. But shouldn't CiphertextForRecipient be of length 256 to be decrypted using the generated key?

The final flow should look something like:

  • Client calls Server’s GenerateDEKs route
  • Server requests Enclave’s attestation
  • Server calls KMS with Enclave’s attestation as RecipientInfo
  • KMS returns DEKs
  • Server sends DEKs to Enclave
  • Enclave decrypts DEKs with its private key, re-encrypts them with the client’s ephemeral public key, sends them back to Server
  • Server sends DEKs to Client
  • Clients decrypts newly-generated DEKs locally with its private key and uses them for encrypting objects

How exactly am I supposed to decrypt CiphertextForRecipient?

I found this answer for a similar problem: How to decrypt the CiphertextForRecipient using the private key in the enclave?

But I really don't know how to do the same thing in Go.

The base implementation of the enclave is as follows (using github.com/hf/nsm):

package main

// ...

var (
    nsmSession        *nsm.Session
    nsmPrivateKey     *rsa.PrivateKey
    nsmPublicKeyBytes []byte
    attestation utils.Attestation = utils.Attestation{Mu: sync.RWMutex{}}
)

func main() {
    log.Println("Starting enclave service...")
    if err := initiateNSM(); err != nil {
        log.Fatalf("Failed to initiate NSM session: %v", err)
    }
    port := getEnvUint32("ENCLAVE_PORT", 5000)
    listener, err := vsock.Listen(port, nil)
    if err != nil {
        log.Fatalf("Failed to start listener: %v", err)
    }
    defer listener.Close()

    // ...

    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Printf("Failed to accept connection: %v", err)
            continue
        }

        log.Printf("New connection accepted from %s", conn.RemoteAddr())
        go handleConnection(conn)
    }
}

func initiateNSM() (err error) {
    // open NSM session
    nsmSession, err = nsm.OpenDefaultSession()
    if err != nil {
        return fmt.Errorf("failed to open NSM session: %v", err)
    }

    // generate RSA key pair for NSM session, nsmSession is used here as a crypto/rand.Reader
    nsmPrivateKey, err = rsa.GenerateKey(nsmSession, 2048)
    if err != nil {
        return fmt.Errorf("failed to generate RSA key: %v", err)
    }

    // extract public key in DER format
    nsmPublicKeyBytes, err = x509.MarshalPKIXPublicKey(&nsmPrivateKey.PublicKey)
    if err != nil {
        return fmt.Errorf("failed to marshal RSA public key: %v", err)
    }

    log.Printf(
        "RSA public key fingerprint: %x",
        sha256.Sum256(nsmPublicKeyBytes),
    )

    log.Println("NSM session initiated")

    return refreshAttestation()
}

func refreshAttestation() error {
    att, err := utils.RequestAttestation(nsmSession, nsmPublicKeyBytes)
    if err != nil {
        return err
    }

    attestation.Mu.Lock()
    attestation.Doc = append([]byte(nil), att...)
    attestation.Mu.Unlock()

    log.Printf("Attestation document refreshed: len=%d", len(attestation.Doc))
    return nil
}

func startAttestationRefresher(ctx context.Context) {
    ticker := time.NewTicker(4 * time.Minute)
    go func() {
        defer ticker.Stop()
        for {
            select {
            case <-ticker.C:
                if err := refreshAttestation(); err != nil {
                    log.Printf("Failed to refresh attestation document: %v", err)
                }
            case <-ctx.Done():
                return
            }
        }
    }()
}
package utils

type Attestation struct {
    Mu  sync.RWMutex
    Doc []byte
}

func RequestAttestation(nsmSession *nsm.Session, nsmPublicKeyBytes []byte) ([]byte, error) {
    req := request.Attestation{
        PublicKey: nsmPublicKeyBytes,
        Nonce:     nil,
        UserData:  nil,
    }
    res, err := nsmSession.Send(&req)
    if err != nil {
        return nil, fmt.Errorf("failed to get attestation document: %v", err)
    }
    if res.Attestation == nil || len(res.Attestation.Document) == 0 {
        return nil, fmt.Errorf("received empty attestation document from NSM")
    }
    return res.Attestation.Document, nil
}

This is the server/proxy route:

func (p *VsockProxy) HandleSessionGenerateDEK(c *gin.Context) {
    // Validate request (SessionGenerateDEKRequest)
    var generateDEKReq enclaveproto.SessionGenerateDEKRequest
    if err := c.BindJSON(&generateDEKReq); err != nil {
        c.JSON(400, gin.H{"error": fmt.Sprintf("Invalid session unwrap request: %v", err)})
        return
    }
    if generateDEKReq.Count <= 0 || generateDEKReq.Count > 100 {
        c.JSON(400, gin.H{"error": "Count must be between 1 and 100"})
        return
    }
    payloadBytes, err := json.Marshal(generateDEKReq)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
        return
    }

    // Request attestation document from enclave (GetAttestationRequest)
    attReq := enclaveproto.Request{
        Type: "get_attestation",
    }
    resAttBytes, err := p.sendToEnclave(attReq, 5*time.Second)
    if err != nil {
        c.JSON(500, gin.H{"error": err.Error()})
        return
    }
    var resAtt enclaveproto.Response[enclaveproto.GetAttestationResponse]
    if err := json.Unmarshal(resAttBytes, &resAtt); err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal get attestation response: %v", err)})
        return
    }
    att, err := base64.StdEncoding.DecodeString(resAtt.Data.Attestation)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to decode attestation document: %v", err)})
        return
    }
    if !resAtt.Success {
        c.JSON(500, gin.H{"error": resAtt.Error})
        return
    }
    if len(att) == 0 {
        c.JSON(500, gin.H{"error": "Empty attestation document"})
        return
    }

    // Generate data keys with KMS using attestation document
    input := &kms.GenerateDataKeyInput{
        KeyId:   &generateDEKReq.KeyId,
        KeySpec: "AES_256",
        Recipient: &kmsTypes.RecipientInfo{
            AttestationDocument:    att,
            KeyEncryptionAlgorithm: kmsTypes.KeyEncryptionMechanismRsaesOaepSha256,
        },
    }
    deks := make([]enclaveproto.EnclaveDEKToPrepare, generateDEKReq.Count)
    for i := 0; i < generateDEKReq.Count; i++ {
        out, err := p.KMSClient.GenerateDataKey(c.Request.Context(), input)
        if err != nil {
            c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to generate data key: %v", err)})
            return
        }
        if out.CiphertextForRecipient == nil {
            c.JSON(500, gin.H{"error": "No CiphertextForRecipient in KMS response"})
            return
        }
        deks[i] = enclaveproto.EnclaveDEKToPrepare{
            CiphertextBlob:         base64.StdEncoding.EncodeToString(out.CiphertextBlob),
            CiphertextForRecipient: base64.StdEncoding.EncodeToString(out.CiphertextForRecipient),
        }
        log.Printf("Generated DEK %d: CiphertextBlob len=%d, CiphertextForRecipient len=%d", i, len(out.CiphertextBlob), len(out.CiphertextForRecipient))
    }

    // Prepare DEK in enclave (EnclavePrepareDEKRequest)
    prepareDEKReq := enclaveproto.EnclavePrepareDEKRequest{
        DEKs:      deks,
        SessionId: generateDEKReq.SessionId,
    }
    payloadBytes, err = json.Marshal(prepareDEKReq)
    if err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
        return
    }
    req := enclaveproto.Request{
        Type:    "session_prepare_dek",
        Payload: json.RawMessage(payloadBytes),
    }
    resBytes, err := p.sendToEnclave(req, 30*time.Second)
    if err != nil {
        c.JSON(500, gin.H{"error": err.Error()})
        return
    }
    var prepRes enclaveproto.Response[enclaveproto.EnclavePrepareDEKResponse]
    if err := json.Unmarshal(resBytes, &prepRes); err != nil {
        c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal session unwrap response: %v", err)})
        return
    }
    if !prepRes.Success {
        c.JSON(500, gin.H{"error": prepRes.Error})
        return
    }

    // Build response
    res := enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
        Success: true,
        Data: enclaveproto.SessionGenerateDEKResponse{
            DEKs: prepRes.Data.DEKs,
        },
    }

    c.JSON(200, res)
}

And this is the enclave handler that is called:

func HandleSessionPrepareDEK(encoder *json.Encoder, payload json.RawMessage, nsmSession *nsm.Session, nsmPrivKey *rsa.PrivateKey) {
    var req enclaveproto.EnclavePrepareDEKRequest
    if err := json.Unmarshal(payload, &req); err != nil {
        utils.SendError(encoder, fmt.Sprintf("Invalid session generate DEK request: %v", err))
        return
    }

    sess, ok := session.GetSession(req.SessionId)
    if !ok {
        utils.SendError(encoder, "Invalid or expired session ID")
        return
    }

    aead, err := chacha20poly1305.NewX(sess.Key)
    if err != nil {
        utils.SendError(encoder, fmt.Sprintf("Failed to create AEAD cipher: %v", err))
        return
    }

    results := make([]enclaveproto.GeneratedDEK, 0, len(req.DEKs))

    for _, dek := range req.DEKs {
        ciphertextForRecipient, err := base64.StdEncoding.DecodeString(dek.CiphertextForRecipient)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextForRecipient: %v", err))
            return
        }
        ciphertextBlob, err := base64.StdEncoding.DecodeString(dek.CiphertextBlob)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextBlob: %v", err))
            return
        }
        log.Printf(
            "CiphertextForRecipient len=%d CiphertextBlob len=%d",
            len(ciphertextForRecipient),
            len(ciphertextBlob),
        )
        if len(ciphertextForRecipient) != 256 { // <--- THIS THROWS HERE!!!
            log.Printf("INVALID RSA CIPHERTEXT LENGTH: %d", len(ciphertextForRecipient))
            utils.SendError(encoder, "invalid CiphertextForRecipient length")
            return
        }
        plainDEK, err := rsa.DecryptOAEP(sha256.New(), nsmSession, nsmPrivKey, ciphertextForRecipient, nil)
        if err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to decrypt data key: %v", err))
            return
        }

        // nonce for session encryption
        nonce := make([]byte, chacha20poly1305.NonceSizeX) // 24 bytes
        if _, err := nsmSession.Read(nonce); err != nil {
            utils.SendError(encoder, fmt.Sprintf("Failed to read nonce: %v", err))
            return
        }

        // seal DEK with session key
        sealedDEK := aead.Seal(nil, nonce, plainDEK, nil)

        utils.Zero(plainDEK)

        results = append(results, enclaveproto.GeneratedDEK{
            SealedDEK:       base64.StdEncoding.EncodeToString(sealedDEK),
            KmsEncryptedDEK: dek.CiphertextBlob,
            Nonce:           base64.StdEncoding.EncodeToString(nonce),
        })
    }

    res := enclaveproto.SessionGenerateDEKResponse{
        DEKs: results,
    }

    utils.SendResponse(encoder, enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
        Success: true,
        Data:    res,
    })
}

Thanks in advance for any help.

0

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.