Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 141 additions & 23 deletions src/wh_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -1430,30 +1430,48 @@ int wh_Client_KeyCacheDmaRequest(whClientContext* c, uint32_t flags,
const void* keyAddr, uint16_t keySz,
uint16_t keyId)
{
int ret;
whMessageKeystore_CacheDmaRequest* req = NULL;
uintptr_t keyAddrPtr = 0;
uint16_t capSz = 0;
int ret = WH_ERROR_OK;
whMessageKeystore_CacheDmaRequest* req = NULL;
uintptr_t keyAddrPtr = 0;
uint16_t capSz = 0;
int keyAddrAcquired = 0;

if (c == NULL || (labelSz > 0 && label == NULL)) {
return WH_ERROR_BADARGS;
}
/* Fail fast if busy: don't acquire a mapping a rejected send would leak. */
if (wh_CommClient_IsRequestPending(c->comm) == 1) {
return WH_ERROR_REQUEST_PENDING;
}

req = (whMessageKeystore_CacheDmaRequest*)wh_CommClient_GetDataPtr(c->comm);
if (req == NULL) {
return WH_ERROR_BADARGS;
}
memset(req, 0, sizeof(*req));
req->id = keyId;
req->flags = flags;
req->labelSz = 0;

/* Set up DMA buffer info */
req->id = keyId;
req->flags = flags;
req->labelSz = 0;
req->key.sz = keySz;
ret = wh_Client_DmaProcessClientAddress(
req->key.addr = 0;

/* Clear the slot up front so a skipped PRE leaves nothing for POST. */
c->dma.asyncCtx.buf.sz = 0;

/* PRE-translate the input key buffer. POST runs in the Response, not here:
* the server reads the buffer between request and response, so an
* in-request POST would free the scratch too early (use-after-free). */
ret = wh_Client_DmaProcessClientAddress(
c, (uintptr_t)keyAddr, (void**)&keyAddrPtr, keySz,
WH_DMA_OPER_CLIENT_READ_PRE, (whDmaFlags){0});
req->key.addr = keyAddrPtr;
if (ret == WH_ERROR_OK) {
keyAddrAcquired = 1;
req->key.addr = (uint64_t)keyAddrPtr;
c->dma.asyncCtx.buf.xformedAddr = keyAddrPtr;
c->dma.asyncCtx.buf.clientAddr = (uintptr_t)keyAddr;
c->dma.asyncCtx.buf.sz = keySz;
c->dma.asyncCtx.buf.postOper = WH_DMA_OPER_CLIENT_READ_POST;
}

/* Copy label if provided, truncate if necessary */
if (labelSz > 0 && label != NULL) {
Expand All @@ -1467,9 +1485,10 @@ int wh_Client_KeyCacheDmaRequest(whClientContext* c, uint32_t flags,
sizeof(*req), (uint8_t*)req);
}

(void)wh_Client_DmaProcessClientAddress(
c, (uintptr_t)keyAddr, (void**)&keyAddrPtr, keySz,
WH_DMA_OPER_CLIENT_READ_POST, (whDmaFlags){0});
if (ret != WH_ERROR_OK && keyAddrAcquired) {
/* SendRequest failed: the Response will not run, so POST now. */
(void)wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
}
return ret;
}

Expand All @@ -1492,6 +1511,9 @@ int wh_Client_KeyCacheDmaResponse(whClientContext* c, uint16_t* keyId)
}

ret = wh_Client_RecvResponse(c, &group, &action, &size, (uint8_t*)resp);
if (ret == WH_ERROR_NOTREADY) {
return ret;
}

if (ret == 0) {
/* Validate response */
Expand All @@ -1510,6 +1532,15 @@ int wh_Client_KeyCacheDmaResponse(whClientContext* c, uint16_t* keyId)
}
}
}

/* POST cleanup: release the mapping once the server has read it. Surface a
* POST failure if the operation otherwise succeeded. */
{
int postRc = wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
if (ret == WH_ERROR_OK) {
ret = postRc;
}
}
return ret;
}

Expand All @@ -1531,23 +1562,56 @@ int wh_Client_KeyCacheDma(whClientContext* c, uint32_t flags, uint8_t* label,
int wh_Client_KeyExportDmaRequest(whClientContext* c, uint16_t keyId,
const void* keyAddr, uint16_t keySz)
{
whMessageKeystore_ExportDmaRequest* req = NULL;
whMessageKeystore_ExportDmaRequest* req = NULL;
uintptr_t keyAddrPtr = 0;
int ret = WH_ERROR_OK;
int keyAddrAcquired = 0;

if (c == NULL || keyId == WH_KEYID_ERASED) {
return WH_ERROR_BADARGS;
}
/* Fail fast if busy: don't acquire a mapping a rejected send would leak. */
if (wh_CommClient_IsRequestPending(c->comm) == 1) {
return WH_ERROR_REQUEST_PENDING;
}

req =
(whMessageKeystore_ExportDmaRequest*)wh_CommClient_GetDataPtr(c->comm);
if (req == NULL) {
return WH_ERROR_BADARGS;
}

req->id = keyId;
req->key.addr = (uint64_t)((uintptr_t)keyAddr);
req->key.addr = 0;
req->key.sz = keySz;

return wh_Client_SendRequest(c, WH_MESSAGE_GROUP_KEY, WH_KEY_EXPORT_DMA,
sizeof(*req), (uint8_t*)req);
/* Clear the slot up front so a skipped PRE leaves nothing for POST. */
c->dma.asyncCtx.buf.sz = 0;

/* PRE-translate the output key buffer; the server fills it and the
* Response POST copies the result back and releases it. */
ret = wh_Client_DmaProcessClientAddress(
c, (uintptr_t)keyAddr, (void**)&keyAddrPtr, keySz,
WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0});
if (ret == WH_ERROR_OK) {
keyAddrAcquired = 1;
req->key.addr = (uint64_t)keyAddrPtr;
c->dma.asyncCtx.buf.xformedAddr = keyAddrPtr;
c->dma.asyncCtx.buf.clientAddr = (uintptr_t)keyAddr;
c->dma.asyncCtx.buf.sz = keySz;
c->dma.asyncCtx.buf.postOper = WH_DMA_OPER_CLIENT_WRITE_POST;
}

if (ret == WH_ERROR_OK) {
ret = wh_Client_SendRequest(c, WH_MESSAGE_GROUP_KEY, WH_KEY_EXPORT_DMA,
sizeof(*req), (uint8_t*)req);
}

if (ret != WH_ERROR_OK && keyAddrAcquired) {
/* SendRequest failed: the Response will not run, so POST now. */
(void)wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
}
return ret;
}

int wh_Client_KeyExportDmaResponse(whClientContext* c, uint8_t* label,
Expand All @@ -1571,6 +1635,9 @@ int wh_Client_KeyExportDmaResponse(whClientContext* c, uint8_t* label,

rc = wh_Client_RecvResponse(c, &resp_group, &resp_action, &resp_size,
(uint8_t*)resp);
if (rc == WH_ERROR_NOTREADY) {
return rc;
}
if (rc == 0) {
/* Validate response */
if ((resp_group != WH_MESSAGE_GROUP_KEY) ||
Expand All @@ -1595,6 +1662,15 @@ int wh_Client_KeyExportDmaResponse(whClientContext* c, uint8_t* label,
}
}
}

/* POST cleanup: copy results back and release the mapping; surface a POST
* failure if the operation otherwise succeeded. */
{
int postRc = wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
if (rc == WH_ERROR_OK) {
rc = postRc;
}
}
return rc;
}

Expand All @@ -1616,26 +1692,57 @@ int wh_Client_KeyExportPublicDmaRequest(whClientContext* c, whKeyId keyId,
uint16_t algo, void* keyAddr,
uint16_t keySz)
{
whMessageKeystore_ExportPublicDmaRequest* req = NULL;
whMessageKeystore_ExportPublicDmaRequest* req = NULL;
uintptr_t keyAddrPtr = 0;
int ret = WH_ERROR_OK;
int keyAddrAcquired = 0;

if (c == NULL || keyId == WH_KEYID_ERASED) {
return WH_ERROR_BADARGS;
}
/* Fail fast if busy: don't acquire a mapping a rejected send would leak. */
if (wh_CommClient_IsRequestPending(c->comm) == 1) {
return WH_ERROR_REQUEST_PENDING;
}

req =
(whMessageKeystore_ExportPublicDmaRequest*)wh_CommClient_GetDataPtr(
c->comm);
if (req == NULL) {
return WH_ERROR_BADARGS;
}

req->id = keyId;
req->algo = algo;
req->key.addr = (uint64_t)((uintptr_t)keyAddr);
req->key.addr = 0;
req->key.sz = keySz;

return wh_Client_SendRequest(c, WH_MESSAGE_GROUP_KEY,
WH_KEY_EXPORT_PUBLIC_DMA, sizeof(*req),
(uint8_t*)req);
/* Clear the slot up front so a skipped PRE leaves nothing for POST. */
c->dma.asyncCtx.buf.sz = 0;

/* PRE-translate the output public key buffer; see KeyExportDmaRequest. */
ret = wh_Client_DmaProcessClientAddress(
c, (uintptr_t)keyAddr, (void**)&keyAddrPtr, keySz,
WH_DMA_OPER_CLIENT_WRITE_PRE, (whDmaFlags){0});
if (ret == WH_ERROR_OK) {
keyAddrAcquired = 1;
req->key.addr = (uint64_t)keyAddrPtr;
c->dma.asyncCtx.buf.xformedAddr = keyAddrPtr;
c->dma.asyncCtx.buf.clientAddr = (uintptr_t)keyAddr;
c->dma.asyncCtx.buf.sz = keySz;
c->dma.asyncCtx.buf.postOper = WH_DMA_OPER_CLIENT_WRITE_POST;
}

if (ret == WH_ERROR_OK) {
ret = wh_Client_SendRequest(c, WH_MESSAGE_GROUP_KEY,
WH_KEY_EXPORT_PUBLIC_DMA, sizeof(*req),
(uint8_t*)req);
}

if (ret != WH_ERROR_OK && keyAddrAcquired) {
(void)wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
}
return ret;
}

int wh_Client_KeyExportPublicDmaResponse(whClientContext* c, uint8_t* label,
Expand All @@ -1660,6 +1767,9 @@ int wh_Client_KeyExportPublicDmaResponse(whClientContext* c, uint8_t* label,

rc = wh_Client_RecvResponse(c, &resp_group, &resp_action, &resp_size,
(uint8_t*)resp);
if (rc == WH_ERROR_NOTREADY) {
return rc;
}
if (rc == 0) {
if (resp_size != sizeof(*resp)) {
rc = WH_ERROR_ABORTED;
Expand All @@ -1679,6 +1789,14 @@ int wh_Client_KeyExportPublicDmaResponse(whClientContext* c, uint8_t* label,
}
}
}

/* POST cleanup; see KeyExportDmaResponse. */
{
int postRc = wh_Client_DmaAsyncPost(c, &c->dma.asyncCtx.buf);
if (rc == WH_ERROR_OK) {
rc = postRc;
}
}
return rc;
}

Expand Down
23 changes: 22 additions & 1 deletion src/wh_client_dma.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@
int wh_Client_DmaRegisterAllowList(whClientContext* client,
const whDmaAddrAllowList* allowlist)
{
if (NULL == client || NULL == allowlist) {
if (NULL == client) {
return WH_ERROR_BADARGS;
}

/* A NULL allowlist clears any previously registered list (no enforcement),
* symmetric with wh_Client_DmaRegisterCb(NULL). */
client->dma.dmaAddrAllowList = allowlist;

return WH_ERROR_OK;
Expand Down Expand Up @@ -94,4 +96,23 @@ int wh_Client_DmaProcessClientAddress(whClientContext* client,
}
return rc;
}

int wh_Client_DmaAsyncPost(whClientContext* client, whClientDmaAsyncBuf* buf)
{
int rc;
uintptr_t addr;

if (client == NULL || buf == NULL || buf->sz == 0) {
return WH_ERROR_OK;
}

addr = buf->xformedAddr;
rc = wh_Client_DmaProcessClientAddress(client, buf->clientAddr,
(void**)&addr, (size_t)buf->sz,
buf->postOper, (whDmaFlags){0});
/* Clear the slot even on failure so a later Response cannot re-run the
* POST; the failure is returned to the caller. */
buf->sz = 0;
return rc;
}
#endif /* WOLFHSM_CFG_DMA */
Loading
Loading