3737#include <linux/vfio.h>
3838#include <linux/workqueue.h>
3939#include <linux/notifier.h>
40+ #include <linux/mm_inline.h>
4041#include "vfio.h"
4142
4243#define DRIVER_VERSION "0.2"
@@ -92,6 +93,7 @@ struct vfio_dma {
9293 bool iommu_mapped ;
9394 bool lock_cap ; /* capable(CAP_IPC_LOCK) */
9495 bool vaddr_invalid ;
96+ bool has_rsvd ; /* has 1 or more rsvd pfns */
9597 struct task_struct * task ;
9698 struct rb_root pfn_list ; /* Ex-user pinned pfn list */
9799 unsigned long * bitmap ;
@@ -318,24 +320,35 @@ static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
318320/*
319321 * Helper Functions for host iova-pfn list
320322 */
321- static struct vfio_pfn * vfio_find_vpfn (struct vfio_dma * dma , dma_addr_t iova )
323+
324+ /*
325+ * Find the vfio_pfn closest to the rb tree root whose iova lies in the
326+ * half-open range [iova_start, iova_end). Returns NULL if none does.
327+ */
328+ static struct vfio_pfn * vfio_find_vpfn_range (struct vfio_dma * dma ,
329+ dma_addr_t iova_start , dma_addr_t iova_end )
322330{
323331 struct vfio_pfn * vpfn ;
324332 struct rb_node * node = dma -> pfn_list .rb_node ;
325333
326334 while (node ) {
327335 vpfn = rb_entry (node , struct vfio_pfn , node );
328336
329- if (iova < vpfn -> iova )
337+ if (iova_end <= vpfn -> iova )
330338 node = node -> rb_left ;
331- else if (iova > vpfn -> iova )
339+ else if (iova_start > vpfn -> iova )
332340 node = node -> rb_right ;
333341 else
334342 return vpfn ;
335343 }
336344 return NULL ;
337345}
338346
347+ static inline struct vfio_pfn * vfio_find_vpfn (struct vfio_dma * dma , dma_addr_t iova )
348+ {
349+ return vfio_find_vpfn_range (dma , iova , iova + 1 );
350+ }
351+
339352static void vfio_link_pfn (struct vfio_dma * dma ,
340353 struct vfio_pfn * new )
341354{
@@ -614,6 +627,39 @@ static long vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
614627 return ret ;
615628}
616629
630+
631+ static long vpfn_pages (struct vfio_dma * dma ,
632+ dma_addr_t iova_start , long nr_pages )
633+ {
634+ dma_addr_t iova_end = iova_start + (nr_pages << PAGE_SHIFT );
635+ struct vfio_pfn * top = vfio_find_vpfn_range (dma , iova_start , iova_end );
636+ long ret = 1 ;
637+ struct vfio_pfn * vpfn ;
638+ struct rb_node * prev ;
639+ struct rb_node * next ;
640+
641+ if (likely (!top ))
642+ return 0 ;
643+
644+ prev = next = & top -> node ;
645+
646+ while ((prev = rb_prev (prev ))) {
647+ vpfn = rb_entry (prev , struct vfio_pfn , node );
648+ if (vpfn -> iova < iova_start )
649+ break ;
650+ ret ++ ;
651+ }
652+
653+ while ((next = rb_next (next ))) {
654+ vpfn = rb_entry (next , struct vfio_pfn , node );
655+ if (vpfn -> iova >= iova_end )
656+ break ;
657+ ret ++ ;
658+ }
659+
660+ return ret ;
661+ }
662+
617663/*
618664 * Attempt to pin pages. We really don't want to track all the pfns and
619665 * the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -687,32 +733,47 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
687733 * and rsvd here, and therefore continues to use the batch.
688734 */
689735 while (true) {
736+ long nr_pages , acct_pages = 0 ;
737+
690738 if (pfn != * pfn_base + pinned ||
691739 rsvd != is_invalid_reserved_pfn (pfn ))
692740 goto out ;
693741
742+ /*
743+ * Using GUP with the FOLL_LONGTERM in
744+ * vaddr_get_pfns() will not return invalid
745+ * or reserved pages.
746+ */
747+ nr_pages = num_pages_contiguous (
748+ & batch -> pages [batch -> offset ],
749+ batch -> size );
750+ if (!rsvd ) {
751+ acct_pages = nr_pages ;
752+ acct_pages -= vpfn_pages (dma , iova , nr_pages );
753+ }
754+
694755 /*
695756 * Reserved pages aren't counted against the user,
696757 * externally pinned pages are already counted against
697758 * the user.
698759 */
699- if (! rsvd && ! vfio_find_vpfn ( dma , iova ) ) {
760+ if (acct_pages ) {
700761 if (!dma -> lock_cap &&
701- mm -> locked_vm + lock_acct + 1 > limit ) {
762+ mm -> locked_vm + lock_acct + acct_pages > limit ) {
702763 pr_warn ("%s: RLIMIT_MEMLOCK (%ld) exceeded\n" ,
703764 __func__ , limit << PAGE_SHIFT );
704765 ret = - ENOMEM ;
705766 goto unpin_out ;
706767 }
707- lock_acct ++ ;
768+ lock_acct += acct_pages ;
708769 }
709770
710- pinned ++ ;
711- npage -- ;
712- vaddr += PAGE_SIZE ;
713- iova += PAGE_SIZE ;
714- batch -> offset ++ ;
715- batch -> size -- ;
771+ pinned += nr_pages ;
772+ npage -= nr_pages ;
773+ vaddr += PAGE_SIZE * nr_pages ;
774+ iova += PAGE_SIZE * nr_pages ;
775+ batch -> offset += nr_pages ;
776+ batch -> size -= nr_pages ;
716777
717778 if (!batch -> size )
718779 break ;
@@ -722,6 +783,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
722783 }
723784
724785out :
786+ dma -> has_rsvd |= rsvd ;
725787 ret = vfio_lock_acct (dma , lock_acct , false);
726788
727789unpin_out :
@@ -738,21 +800,29 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
738800 return pinned ;
739801}
740802
803+ static inline void put_valid_unreserved_pfns (unsigned long start_pfn ,
804+ unsigned long npage , int prot )
805+ {
806+ unpin_user_page_range_dirty_lock (pfn_to_page (start_pfn ), npage ,
807+ prot & IOMMU_WRITE );
808+ }
809+
741810static long vfio_unpin_pages_remote (struct vfio_dma * dma , dma_addr_t iova ,
742811 unsigned long pfn , unsigned long npage ,
743812 bool do_accounting )
744813{
745- long unlocked = 0 , locked = 0 ;
746- long i ;
814+ long unlocked = 0 , locked = vpfn_pages (dma , iova , npage );
747815
748- for (i = 0 ; i < npage ; i ++ , iova += PAGE_SIZE ) {
749- if (put_pfn (pfn ++ , dma -> prot )) {
750- unlocked ++ ;
751- if (vfio_find_vpfn (dma , iova ))
752- locked ++ ;
753- }
754- }
816+ if (dma -> has_rsvd ) {
817+ unsigned long i ;
755818
819+ for (i = 0 ; i < npage ; i ++ )
820+ if (put_pfn (pfn ++ , dma -> prot ))
821+ unlocked ++ ;
822+ } else {
823+ put_valid_unreserved_pfns (pfn , npage , dma -> prot );
824+ unlocked = npage ;
825+ }
756826 if (do_accounting )
757827 vfio_lock_acct (dma , locked - unlocked , true);
758828
0 commit comments