I'm building a route for my go (gin) server that generates a Data Encryption Key (DEK), following the principle of zero-trust (the backend never sees the plaintext).
Right now, the client generates a pair of keys and shares the public one to the server (that also acts as a proxy between the client and the enclave).
The enclave also generates a pair of keys and then attach the public one to the attestation. However when I try to generate a DEK, the lengths of CiphertextBlob and CiphertextForRecipient are:
Generated DEK 0: CiphertextBlob len=184, CiphertextForRecipient len=493
They are both []byte, so I think they are already decoded from B64. But shouldn't CiphertextForRecipient be of length 256 to be decrypted using the generated key?
The final flow should look something like:
- Client calls Server’s GenerateDEKs route
- Server requests Enclave’s attestation
- Server calls KMS with Enclave’s attestation as RecipientInfo
- KMS returns DEKs
- Server sends DEKs to Enclave
- Enclave decrypts DEKs with its private key, re-encrypts them with the client’s ephemeral public key, sends them back to Server
- Server sends DEKs to Client
- Clients decrypts newly-generated DEKs locally with its private key and uses them for encrypting objects
How exactly am I supposed to decrypt CiphertextForRecipient?
I found this answer for a similar problem: How to decrypt the CiphertextForRecipient using the private key in the enclave?
But I really don't know how to do the same thing in Go.
The base implementation of the enclave is as follows (using github.com/hf/nsm):
package main
// ...
var (
nsmSession *nsm.Session
nsmPrivateKey *rsa.PrivateKey
nsmPublicKeyBytes []byte
attestation utils.Attestation = utils.Attestation{Mu: sync.RWMutex{}}
)
func main() {
log.Println("Starting enclave service...")
if err := initiateNSM(); err != nil {
log.Fatalf("Failed to initiate NSM session: %v", err)
}
port := getEnvUint32("ENCLAVE_PORT", 5000)
listener, err := vsock.Listen(port, nil)
if err != nil {
log.Fatalf("Failed to start listener: %v", err)
}
defer listener.Close()
// ...
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Failed to accept connection: %v", err)
continue
}
log.Printf("New connection accepted from %s", conn.RemoteAddr())
go handleConnection(conn)
}
}
func initiateNSM() (err error) {
// open NSM session
nsmSession, err = nsm.OpenDefaultSession()
if err != nil {
return fmt.Errorf("failed to open NSM session: %v", err)
}
// generate RSA key pair for NSM session, nsmSession is used here as a crypto/rand.Reader
nsmPrivateKey, err = rsa.GenerateKey(nsmSession, 2048)
if err != nil {
return fmt.Errorf("failed to generate RSA key: %v", err)
}
// extract public key in DER format
nsmPublicKeyBytes, err = x509.MarshalPKIXPublicKey(&nsmPrivateKey.PublicKey)
if err != nil {
return fmt.Errorf("failed to marshal RSA public key: %v", err)
}
log.Printf(
"RSA public key fingerprint: %x",
sha256.Sum256(nsmPublicKeyBytes),
)
log.Println("NSM session initiated")
return refreshAttestation()
}
func refreshAttestation() error {
att, err := utils.RequestAttestation(nsmSession, nsmPublicKeyBytes)
if err != nil {
return err
}
attestation.Mu.Lock()
attestation.Doc = append([]byte(nil), att...)
attestation.Mu.Unlock()
log.Printf("Attestation document refreshed: len=%d", len(attestation.Doc))
return nil
}
func startAttestationRefresher(ctx context.Context) {
ticker := time.NewTicker(4 * time.Minute)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := refreshAttestation(); err != nil {
log.Printf("Failed to refresh attestation document: %v", err)
}
case <-ctx.Done():
return
}
}
}()
}
package utils
type Attestation struct {
Mu sync.RWMutex
Doc []byte
}
func RequestAttestation(nsmSession *nsm.Session, nsmPublicKeyBytes []byte) ([]byte, error) {
req := request.Attestation{
PublicKey: nsmPublicKeyBytes,
Nonce: nil,
UserData: nil,
}
res, err := nsmSession.Send(&req)
if err != nil {
return nil, fmt.Errorf("failed to get attestation document: %v", err)
}
if res.Attestation == nil || len(res.Attestation.Document) == 0 {
return nil, fmt.Errorf("received empty attestation document from NSM")
}
return res.Attestation.Document, nil
}
This is the server/proxy route:
func (p *VsockProxy) HandleSessionGenerateDEK(c *gin.Context) {
// Validate request (SessionGenerateDEKRequest)
var generateDEKReq enclaveproto.SessionGenerateDEKRequest
if err := c.BindJSON(&generateDEKReq); err != nil {
c.JSON(400, gin.H{"error": fmt.Sprintf("Invalid session unwrap request: %v", err)})
return
}
if generateDEKReq.Count <= 0 || generateDEKReq.Count > 100 {
c.JSON(400, gin.H{"error": "Count must be between 1 and 100"})
return
}
payloadBytes, err := json.Marshal(generateDEKReq)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
return
}
// Request attestation document from enclave (GetAttestationRequest)
attReq := enclaveproto.Request{
Type: "get_attestation",
}
resAttBytes, err := p.sendToEnclave(attReq, 5*time.Second)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
var resAtt enclaveproto.Response[enclaveproto.GetAttestationResponse]
if err := json.Unmarshal(resAttBytes, &resAtt); err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal get attestation response: %v", err)})
return
}
att, err := base64.StdEncoding.DecodeString(resAtt.Data.Attestation)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to decode attestation document: %v", err)})
return
}
if !resAtt.Success {
c.JSON(500, gin.H{"error": resAtt.Error})
return
}
if len(att) == 0 {
c.JSON(500, gin.H{"error": "Empty attestation document"})
return
}
// Generate data keys with KMS using attestation document
input := &kms.GenerateDataKeyInput{
KeyId: &generateDEKReq.KeyId,
KeySpec: "AES_256",
Recipient: &kmsTypes.RecipientInfo{
AttestationDocument: att,
KeyEncryptionAlgorithm: kmsTypes.KeyEncryptionMechanismRsaesOaepSha256,
},
}
deks := make([]enclaveproto.EnclaveDEKToPrepare, generateDEKReq.Count)
for i := 0; i < generateDEKReq.Count; i++ {
out, err := p.KMSClient.GenerateDataKey(c.Request.Context(), input)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to generate data key: %v", err)})
return
}
if out.CiphertextForRecipient == nil {
c.JSON(500, gin.H{"error": "No CiphertextForRecipient in KMS response"})
return
}
deks[i] = enclaveproto.EnclaveDEKToPrepare{
CiphertextBlob: base64.StdEncoding.EncodeToString(out.CiphertextBlob),
CiphertextForRecipient: base64.StdEncoding.EncodeToString(out.CiphertextForRecipient),
}
log.Printf("Generated DEK %d: CiphertextBlob len=%d, CiphertextForRecipient len=%d", i, len(out.CiphertextBlob), len(out.CiphertextForRecipient))
}
// Prepare DEK in enclave (EnclavePrepareDEKRequest)
prepareDEKReq := enclaveproto.EnclavePrepareDEKRequest{
DEKs: deks,
SessionId: generateDEKReq.SessionId,
}
payloadBytes, err = json.Marshal(prepareDEKReq)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to marshal session unwrap request: %v", err)})
return
}
req := enclaveproto.Request{
Type: "session_prepare_dek",
Payload: json.RawMessage(payloadBytes),
}
resBytes, err := p.sendToEnclave(req, 30*time.Second)
if err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
var prepRes enclaveproto.Response[enclaveproto.EnclavePrepareDEKResponse]
if err := json.Unmarshal(resBytes, &prepRes); err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("Failed to unmarshal session unwrap response: %v", err)})
return
}
if !prepRes.Success {
c.JSON(500, gin.H{"error": prepRes.Error})
return
}
// Build response
res := enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
Success: true,
Data: enclaveproto.SessionGenerateDEKResponse{
DEKs: prepRes.Data.DEKs,
},
}
c.JSON(200, res)
}
And this is the enclave handler that is called:
func HandleSessionPrepareDEK(encoder *json.Encoder, payload json.RawMessage, nsmSession *nsm.Session, nsmPrivKey *rsa.PrivateKey) {
var req enclaveproto.EnclavePrepareDEKRequest
if err := json.Unmarshal(payload, &req); err != nil {
utils.SendError(encoder, fmt.Sprintf("Invalid session generate DEK request: %v", err))
return
}
sess, ok := session.GetSession(req.SessionId)
if !ok {
utils.SendError(encoder, "Invalid or expired session ID")
return
}
aead, err := chacha20poly1305.NewX(sess.Key)
if err != nil {
utils.SendError(encoder, fmt.Sprintf("Failed to create AEAD cipher: %v", err))
return
}
results := make([]enclaveproto.GeneratedDEK, 0, len(req.DEKs))
for _, dek := range req.DEKs {
ciphertextForRecipient, err := base64.StdEncoding.DecodeString(dek.CiphertextForRecipient)
if err != nil {
utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextForRecipient: %v", err))
return
}
ciphertextBlob, err := base64.StdEncoding.DecodeString(dek.CiphertextBlob)
if err != nil {
utils.SendError(encoder, fmt.Sprintf("Failed to decode CiphertextBlob: %v", err))
return
}
log.Printf(
"CiphertextForRecipient len=%d CiphertextBlob len=%d",
len(ciphertextForRecipient),
len(ciphertextBlob),
)
if len(ciphertextForRecipient) != 256 { // <--- THIS THROWS HERE!!!
log.Printf("INVALID RSA CIPHERTEXT LENGTH: %d", len(ciphertextForRecipient))
utils.SendError(encoder, "invalid CiphertextForRecipient length")
return
}
plainDEK, err := rsa.DecryptOAEP(sha256.New(), nsmSession, nsmPrivKey, ciphertextForRecipient, nil)
if err != nil {
utils.SendError(encoder, fmt.Sprintf("Failed to decrypt data key: %v", err))
return
}
// nonce for session encryption
nonce := make([]byte, chacha20poly1305.NonceSizeX) // 24 bytes
if _, err := nsmSession.Read(nonce); err != nil {
utils.SendError(encoder, fmt.Sprintf("Failed to read nonce: %v", err))
return
}
// seal DEK with session key
sealedDEK := aead.Seal(nil, nonce, plainDEK, nil)
utils.Zero(plainDEK)
results = append(results, enclaveproto.GeneratedDEK{
SealedDEK: base64.StdEncoding.EncodeToString(sealedDEK),
KmsEncryptedDEK: dek.CiphertextBlob,
Nonce: base64.StdEncoding.EncodeToString(nonce),
})
}
res := enclaveproto.SessionGenerateDEKResponse{
DEKs: results,
}
utils.SendResponse(encoder, enclaveproto.Response[enclaveproto.SessionGenerateDEKResponse]{
Success: true,
Data: res,
})
}
Thanks in advance for any help.