@@ -78,12 +78,33 @@ static struct iommu_domain *get_domain_for_iopf(struct device *dev,
7878 return domain ;
7979}
8080
81+ /* Non-last request of a group. Postpone until the last one. */
82+ static int report_partial_fault (struct iommu_fault_param * fault_param ,
83+ struct iommu_fault * fault )
84+ {
85+ struct iopf_fault * iopf ;
86+
87+ iopf = kzalloc (sizeof (* iopf ), GFP_KERNEL );
88+ if (!iopf )
89+ return - ENOMEM ;
90+
91+ iopf -> fault = * fault ;
92+
93+ mutex_lock (& fault_param -> lock );
94+ list_add (& iopf -> list , & fault_param -> partial );
95+ mutex_unlock (& fault_param -> lock );
96+
97+ return 0 ;
98+ }
99+
81100/**
82- * iommu_handle_iopf - IO Page Fault handler
83- * @fault: fault event
84- * @iopf_param: the fault parameter of the device.
101+ * iommu_report_device_fault() - Report fault event to device driver
102+ * @dev: the device
103+ * @evt: fault event data
85104 *
86- * Add a fault to the device workqueue, to be handled by mm.
105+ * Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
106+ * handler. When this function fails and the fault is recoverable, it is the
107+ * caller's responsibility to complete the fault.
87108 *
88109 * This module doesn't handle PCI PASID Stop Marker; IOMMU drivers must discard
89110 * them before reporting faults. A PASID Stop Marker (LRW = 0b100) doesn't
@@ -118,34 +139,37 @@ static struct iommu_domain *get_domain_for_iopf(struct device *dev,
118139 *
119140 * Return: 0 on success and <0 on error.
120141 */
121- static int iommu_handle_iopf (struct iommu_fault * fault ,
122- struct iommu_fault_param * iopf_param )
142+ int iommu_report_device_fault (struct device * dev , struct iopf_fault * evt )
123143{
124- int ret ;
125- struct iopf_group * group ;
126- struct iommu_domain * domain ;
144+ struct iommu_fault * fault = & evt -> fault ;
145+ struct iommu_fault_param * iopf_param ;
127146 struct iopf_fault * iopf , * next ;
128- struct device * dev = iopf_param -> dev ;
129-
130- lockdep_assert_held ( & iopf_param -> lock ) ;
147+ struct iommu_domain * domain ;
148+ struct iopf_group * group ;
149+ int ret ;
131150
132151 if (fault -> type != IOMMU_FAULT_PAGE_REQ )
133- /* Not a recoverable page fault */
134152 return - EOPNOTSUPP ;
135153
136- if (!(fault -> prm .flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE )) {
137- iopf = kzalloc (sizeof (* iopf ), GFP_KERNEL );
138- if (!iopf )
139- return - ENOMEM ;
140-
141- iopf -> fault = * fault ;
154+ iopf_param = iopf_get_dev_fault_param (dev );
155+ if (!iopf_param )
156+ return - ENODEV ;
142157
143- /* Non-last request of a group. Postpone until the last one */
144- list_add (& iopf -> list , & iopf_param -> partial );
158+ if (!(fault -> prm .flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE )) {
159+ ret = report_partial_fault (iopf_param , fault );
160+ iopf_put_dev_fault_param (iopf_param );
145161
146- return 0 ;
162+ return ret ;
147163 }
148164
165+ /*
166+ * This is the last page fault of a group. Allocate an iopf group and
167+ * pass it to domain's page fault handler. The group holds a reference
168+ * count of the fault parameter. It will be released after response or
169+ * error path of this function. If an error is returned, the caller
170+ * will send a response to the hardware. We need to clean up before
171+ * leaving, otherwise partial faults will be stuck.
172+ */
149173 domain = get_domain_for_iopf (dev , fault );
150174 if (!domain ) {
151175 ret = - EINVAL ;
@@ -154,157 +178,52 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
154178
155179 group = kzalloc (sizeof (* group ), GFP_KERNEL );
156180 if (!group ) {
157- /*
158- * The caller will send a response to the hardware. But we do
159- * need to clean up before leaving, otherwise partial faults
160- * will be stuck.
161- */
162181 ret = - ENOMEM ;
163182 goto cleanup_partial ;
164183 }
165184
166185 group -> fault_param = iopf_param ;
167186 group -> last_fault .fault = * fault ;
168187 INIT_LIST_HEAD (& group -> faults );
188+ INIT_LIST_HEAD (& group -> pending_node );
169189 group -> domain = domain ;
170190 list_add (& group -> last_fault .list , & group -> faults );
171191
172192 /* See if we have partial faults for this group */
193+ mutex_lock (& iopf_param -> lock );
173194 list_for_each_entry_safe (iopf , next , & iopf_param -> partial , list ) {
174195 if (iopf -> fault .prm .grpid == fault -> prm .grpid )
175196 /* Insert *before* the last fault */
176197 list_move (& iopf -> list , & group -> faults );
177198 }
178-
199+ list_add ( & group -> pending_node , & iopf_param -> faults );
179200 mutex_unlock (& iopf_param -> lock );
201+
180202 ret = domain -> iopf_handler (group );
181- mutex_lock (& iopf_param -> lock );
182- if (ret )
203+ if (ret ) {
204+ mutex_lock (& iopf_param -> lock );
205+ list_del_init (& group -> pending_node );
206+ mutex_unlock (& iopf_param -> lock );
183207 iopf_free_group (group );
208+ }
184209
185210 return ret ;
211+
186212cleanup_partial :
213+ mutex_lock (& iopf_param -> lock );
187214 list_for_each_entry_safe (iopf , next , & iopf_param -> partial , list ) {
188215 if (iopf -> fault .prm .grpid == fault -> prm .grpid ) {
189216 list_del (& iopf -> list );
190217 kfree (iopf );
191218 }
192219 }
193- return ret ;
194- }
195-
196- /**
197- * iommu_report_device_fault() - Report fault event to device driver
198- * @dev: the device
199- * @evt: fault event data
200- *
201- * Called by IOMMU drivers when a fault is detected, typically in a threaded IRQ
202- * handler. When this function fails and the fault is recoverable, it is the
203- * caller's responsibility to complete the fault.
204- *
205- * Return 0 on success, or an error.
206- */
207- int iommu_report_device_fault (struct device * dev , struct iopf_fault * evt )
208- {
209- bool last_prq = evt -> fault .type == IOMMU_FAULT_PAGE_REQ &&
210- (evt -> fault .prm .flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE );
211- struct iommu_fault_param * fault_param ;
212- struct iopf_fault * evt_pending ;
213- int ret ;
214-
215- fault_param = iopf_get_dev_fault_param (dev );
216- if (!fault_param )
217- return - EINVAL ;
218-
219- mutex_lock (& fault_param -> lock );
220- if (last_prq ) {
221- evt_pending = kmemdup (evt , sizeof (struct iopf_fault ),
222- GFP_KERNEL );
223- if (!evt_pending ) {
224- ret = - ENOMEM ;
225- goto err_unlock ;
226- }
227- list_add_tail (& evt_pending -> list , & fault_param -> faults );
228- }
229-
230- ret = iommu_handle_iopf (& evt -> fault , fault_param );
231- if (ret )
232- goto err_free ;
233-
234- mutex_unlock (& fault_param -> lock );
235- /* The reference count of fault_param is now held by iopf_group. */
236- if (!last_prq )
237- iopf_put_dev_fault_param (fault_param );
238-
239- return 0 ;
240- err_free :
241- if (last_prq ) {
242- list_del (& evt_pending -> list );
243- kfree (evt_pending );
244- }
245- err_unlock :
246- mutex_unlock (& fault_param -> lock );
247- iopf_put_dev_fault_param (fault_param );
220+ mutex_unlock (& iopf_param -> lock );
221+ iopf_put_dev_fault_param (iopf_param );
248222
249223 return ret ;
250224}
251225EXPORT_SYMBOL_GPL (iommu_report_device_fault );
252226
253- static int iommu_page_response (struct iopf_group * group ,
254- struct iommu_page_response * msg )
255- {
256- bool needs_pasid ;
257- int ret = - EINVAL ;
258- struct iopf_fault * evt ;
259- struct iommu_fault_page_request * prm ;
260- struct device * dev = group -> fault_param -> dev ;
261- const struct iommu_ops * ops = dev_iommu_ops (dev );
262- bool has_pasid = msg -> flags & IOMMU_PAGE_RESP_PASID_VALID ;
263- struct iommu_fault_param * fault_param = group -> fault_param ;
264-
265- /* Only send response if there is a fault report pending */
266- mutex_lock (& fault_param -> lock );
267- if (list_empty (& fault_param -> faults )) {
268- dev_warn_ratelimited (dev , "no pending PRQ, drop response\n" );
269- goto done_unlock ;
270- }
271- /*
272- * Check if we have a matching page request pending to respond,
273- * otherwise return -EINVAL
274- */
275- list_for_each_entry (evt , & fault_param -> faults , list ) {
276- prm = & evt -> fault .prm ;
277- if (prm -> grpid != msg -> grpid )
278- continue ;
279-
280- /*
281- * If the PASID is required, the corresponding request is
282- * matched using the group ID, the PASID valid bit and the PASID
283- * value. Otherwise only the group ID matches request and
284- * response.
285- */
286- needs_pasid = prm -> flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID ;
287- if (needs_pasid && (!has_pasid || msg -> pasid != prm -> pasid ))
288- continue ;
289-
290- if (!needs_pasid && has_pasid ) {
291- /* No big deal, just clear it. */
292- msg -> flags &= ~IOMMU_PAGE_RESP_PASID_VALID ;
293- msg -> pasid = 0 ;
294- }
295-
296- ret = ops -> page_response (dev , evt , msg );
297- list_del (& evt -> list );
298- kfree (evt );
299- break ;
300- }
301-
302- done_unlock :
303- mutex_unlock (& fault_param -> lock );
304-
305- return ret ;
306- }
307-
308227/**
309228 * iopf_queue_flush_dev - Ensure that all queued faults have been processed
310229 * @dev: the endpoint whose faults need to be flushed.
@@ -346,18 +265,26 @@ EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
346265int iopf_group_response (struct iopf_group * group ,
347266 enum iommu_page_response_code status )
348267{
268+ struct iommu_fault_param * fault_param = group -> fault_param ;
349269 struct iopf_fault * iopf = & group -> last_fault ;
270+ struct device * dev = group -> fault_param -> dev ;
271+ const struct iommu_ops * ops = dev_iommu_ops (dev );
350272 struct iommu_page_response resp = {
351273 .pasid = iopf -> fault .prm .pasid ,
352274 .grpid = iopf -> fault .prm .grpid ,
353275 .code = status ,
354276 };
277+ int ret = - EINVAL ;
355278
356- if ((iopf -> fault .prm .flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID ) &&
357- (iopf -> fault .prm .flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID ))
358- resp .flags = IOMMU_PAGE_RESP_PASID_VALID ;
279+ /* Only send response if there is a fault report pending */
280+ mutex_lock (& fault_param -> lock );
281+ if (!list_empty (& group -> pending_node )) {
282+ ret = ops -> page_response (dev , & group -> last_fault , & resp );
283+ list_del_init (& group -> pending_node );
284+ }
285+ mutex_unlock (& fault_param -> lock );
359286
360- return iommu_page_response ( group , & resp ) ;
287+ return ret ;
361288}
362289EXPORT_SYMBOL_GPL (iopf_group_response );
363290
@@ -468,8 +395,9 @@ EXPORT_SYMBOL_GPL(iopf_queue_add_device);
468395 */
469396void iopf_queue_remove_device (struct iopf_queue * queue , struct device * dev )
470397{
471- struct iopf_fault * iopf , * next ;
472- struct iommu_page_response resp ;
398+ struct iopf_fault * partial_iopf ;
399+ struct iopf_fault * next ;
400+ struct iopf_group * group , * temp ;
473401 struct dev_iommu * param = dev -> iommu ;
474402 struct iommu_fault_param * fault_param ;
475403 const struct iommu_ops * ops = dev_iommu_ops (dev );
@@ -483,21 +411,19 @@ void iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
483411 goto unlock ;
484412
485413 mutex_lock (& fault_param -> lock );
486- list_for_each_entry_safe (iopf , next , & fault_param -> partial , list )
487- kfree (iopf );
488-
489- list_for_each_entry_safe (iopf , next , & fault_param -> faults , list ) {
490- memset (& resp , 0 , sizeof (struct iommu_page_response ));
491- resp .pasid = iopf -> fault .prm .pasid ;
492- resp .grpid = iopf -> fault .prm .grpid ;
493- resp .code = IOMMU_PAGE_RESP_INVALID ;
414+ list_for_each_entry_safe (partial_iopf , next , & fault_param -> partial , list )
415+ kfree (partial_iopf );
494416
495- if (iopf -> fault .prm .flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID )
496- resp .flags = IOMMU_PAGE_RESP_PASID_VALID ;
417+ list_for_each_entry_safe (group , temp , & fault_param -> faults , pending_node ) {
418+ struct iopf_fault * iopf = & group -> last_fault ;
419+ struct iommu_page_response resp = {
420+ .pasid = iopf -> fault .prm .pasid ,
421+ .grpid = iopf -> fault .prm .grpid ,
422+ .code = IOMMU_PAGE_RESP_INVALID
423+ };
497424
498425 ops -> page_response (dev , iopf , & resp );
499- list_del (& iopf -> list );
500- kfree (iopf );
426+ list_del_init (& group -> pending_node );
501427 }
502428 mutex_unlock (& fault_param -> lock );
503429
0 commit comments