1414#include <net/sock.h>
1515#include <net/tcp.h>
1616#include <net/tls.h>
17+ #include <net/tls_prot.h>
1718#include <net/handshake.h>
1819#include <linux/inet.h>
1920#include <linux/llist.h>
@@ -118,6 +119,7 @@ struct nvmet_tcp_cmd {
118119 u32 pdu_len ;
119120 u32 pdu_recv ;
120121 int sg_idx ;
122+ char recv_cbuf [CMSG_LEN (sizeof (char ))];
121123 struct msghdr recv_msg ;
122124 struct bio_vec * iov ;
123125 u32 flags ;
@@ -1121,20 +1123,65 @@ static inline bool nvmet_tcp_pdu_valid(u8 type)
11211123 return false;
11221124}
11231125
1126+ static int nvmet_tcp_tls_record_ok (struct nvmet_tcp_queue * queue ,
1127+ struct msghdr * msg , char * cbuf )
1128+ {
1129+ struct cmsghdr * cmsg = (struct cmsghdr * )cbuf ;
1130+ u8 ctype , level , description ;
1131+ int ret = 0 ;
1132+
1133+ ctype = tls_get_record_type (queue -> sock -> sk , cmsg );
1134+ switch (ctype ) {
1135+ case 0 :
1136+ break ;
1137+ case TLS_RECORD_TYPE_DATA :
1138+ break ;
1139+ case TLS_RECORD_TYPE_ALERT :
1140+ tls_alert_recv (queue -> sock -> sk , msg , & level , & description );
1141+ if (level == TLS_ALERT_LEVEL_FATAL ) {
1142+ pr_err ("queue %d: TLS Alert desc %u\n" ,
1143+ queue -> idx , description );
1144+ ret = - ENOTCONN ;
1145+ } else {
1146+ pr_warn ("queue %d: TLS Alert desc %u\n" ,
1147+ queue -> idx , description );
1148+ ret = - EAGAIN ;
1149+ }
1150+ break ;
1151+ default :
1152+ /* discard this record type */
1153+ pr_err ("queue %d: TLS record %d unhandled\n" ,
1154+ queue -> idx , ctype );
1155+ ret = - EAGAIN ;
1156+ break ;
1157+ }
1158+ return ret ;
1159+ }
1160+
11241161static int nvmet_tcp_try_recv_pdu (struct nvmet_tcp_queue * queue )
11251162{
11261163 struct nvme_tcp_hdr * hdr = & queue -> pdu .cmd .hdr ;
1127- int len ;
1164+ int len , ret ;
11281165 struct kvec iov ;
1166+ char cbuf [CMSG_LEN (sizeof (char ))] = {};
11291167 struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
11301168
11311169recv :
11321170 iov .iov_base = (void * )& queue -> pdu + queue -> offset ;
11331171 iov .iov_len = queue -> left ;
1172+ if (queue -> tls_pskid ) {
1173+ msg .msg_control = cbuf ;
1174+ msg .msg_controllen = sizeof (cbuf );
1175+ }
11341176 len = kernel_recvmsg (queue -> sock , & msg , & iov , 1 ,
11351177 iov .iov_len , msg .msg_flags );
11361178 if (unlikely (len < 0 ))
11371179 return len ;
1180+ if (queue -> tls_pskid ) {
1181+ ret = nvmet_tcp_tls_record_ok (queue , & msg , cbuf );
1182+ if (ret < 0 )
1183+ return ret ;
1184+ }
11381185
11391186 queue -> offset += len ;
11401187 queue -> left -= len ;
@@ -1187,16 +1234,22 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd)
11871234static int nvmet_tcp_try_recv_data (struct nvmet_tcp_queue * queue )
11881235{
11891236 struct nvmet_tcp_cmd * cmd = queue -> cmd ;
1190- int ret ;
1237+ int len , ret ;
11911238
11921239 while (msg_data_left (& cmd -> recv_msg )) {
1193- ret = sock_recvmsg (cmd -> queue -> sock , & cmd -> recv_msg ,
1240+ len = sock_recvmsg (cmd -> queue -> sock , & cmd -> recv_msg ,
11941241 cmd -> recv_msg .msg_flags );
1195- if (ret <= 0 )
1196- return ret ;
1242+ if (len <= 0 )
1243+ return len ;
1244+ if (queue -> tls_pskid ) {
1245+ ret = nvmet_tcp_tls_record_ok (cmd -> queue ,
1246+ & cmd -> recv_msg , cmd -> recv_cbuf );
1247+ if (ret < 0 )
1248+ return ret ;
1249+ }
11971250
1198- cmd -> pdu_recv += ret ;
1199- cmd -> rbytes_done += ret ;
1251+ cmd -> pdu_recv += len ;
1252+ cmd -> rbytes_done += len ;
12001253 }
12011254
12021255 if (queue -> data_digest ) {
@@ -1214,20 +1267,30 @@ static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue)
12141267static int nvmet_tcp_try_recv_ddgst (struct nvmet_tcp_queue * queue )
12151268{
12161269 struct nvmet_tcp_cmd * cmd = queue -> cmd ;
1217- int ret ;
1270+ int ret , len ;
1271+ char cbuf [CMSG_LEN (sizeof (char ))] = {};
12181272 struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
12191273 struct kvec iov = {
12201274 .iov_base = (void * )& cmd -> recv_ddgst + queue -> offset ,
12211275 .iov_len = queue -> left
12221276 };
12231277
1224- ret = kernel_recvmsg (queue -> sock , & msg , & iov , 1 ,
1278+ if (queue -> tls_pskid ) {
1279+ msg .msg_control = cbuf ;
1280+ msg .msg_controllen = sizeof (cbuf );
1281+ }
1282+ len = kernel_recvmsg (queue -> sock , & msg , & iov , 1 ,
12251283 iov .iov_len , msg .msg_flags );
1226- if (unlikely (ret < 0 ))
1227- return ret ;
1284+ if (unlikely (len < 0 ))
1285+ return len ;
1286+ if (queue -> tls_pskid ) {
1287+ ret = nvmet_tcp_tls_record_ok (queue , & msg , cbuf );
1288+ if (ret < 0 )
1289+ return ret ;
1290+ }
12281291
1229- queue -> offset += ret ;
1230- queue -> left -= ret ;
1292+ queue -> offset += len ;
1293+ queue -> left -= len ;
12311294 if (queue -> left )
12321295 return - EAGAIN ;
12331296
@@ -1407,6 +1470,10 @@ static int nvmet_tcp_alloc_cmd(struct nvmet_tcp_queue *queue,
14071470 if (!c -> r2t_pdu )
14081471 goto out_free_data ;
14091472
1473+ if (queue -> state == NVMET_TCP_Q_TLS_HANDSHAKE ) {
1474+ c -> recv_msg .msg_control = c -> recv_cbuf ;
1475+ c -> recv_msg .msg_controllen = sizeof (c -> recv_cbuf );
1476+ }
14101477 c -> recv_msg .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL ;
14111478
14121479 list_add_tail (& c -> entry , & queue -> free_list );
0 commit comments