@@ -331,7 +331,7 @@ module Operations =
331331 | _ -> failwith " Not implemented yet"
332332
333333 /// <summary >
334- /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations.
334+ /// CSR Matrix - sparse vector multiplication. Optimized for bool OR and AND operations by skipping reduction stage .
335335 /// </summary >
336336 /// <param name =" add " >Type of binary function to reduce entries.</param >
337337 /// <param name =" mul " >Type of binary function to combine entries.</param >
@@ -352,6 +352,50 @@ module Operations =
352352 | ClMatrix.CSR m, ClVector.Sparse v -> Option.map ClVector.Sparse ( run queue m v)
353353 | _ -> failwith " Not implemented yet"
354354
355+ /// <summary >
356+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented.
357+ /// </summary >
358+ /// <param name =" add " >Type of binary function to reduce entries.</param >
359+ /// <param name =" mul " >Type of binary function to combine entries.</param >
360+ /// <param name =" clContext " >OpenCL context.</param >
361+ /// <param name =" workGroupSize " >Should be a power of 2 and greater than 1.</param >
362+ let SpMSpVMasked
363+ ( add : Expr < 'c option -> 'c option -> 'c option >)
364+ ( mul : Expr < 'a option -> 'b option -> 'c option >)
365+ ( clContext : ClContext )
366+ workGroupSize
367+ =
368+
369+ let run =
370+ SpMSpV.Masked.runMasked add mul clContext workGroupSize
371+
372+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
373+ match matrix, vector, mask with
374+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
375+ | _ -> failwith " Not implemented yet"
376+
377+ /// <summary >
378+ /// CSR Matrix - sparse vector multiplication with mask. Mask is complemented. Optimized for bool OR and AND operations by skipping reduction stage.
379+ /// </summary >
380+ /// <param name =" add " >Type of binary function to reduce entries.</param >
381+ /// <param name =" mul " >Type of binary function to combine entries.</param >
382+ /// <param name =" clContext " >OpenCL context.</param >
383+ /// <param name =" workGroupSize " >Should be a power of 2 and greater than 1.</param >
384+ let SpMSpVMaskedBool
385+ ( add : Expr < bool option -> bool option -> bool option >)
386+ ( mul : Expr < bool option -> bool option -> bool option >)
387+ ( clContext : ClContext )
388+ workGroupSize
389+ =
390+
391+ let run =
392+ SpMSpV.Masked.runMaskedBoolStandard add mul clContext workGroupSize
393+
394+ fun ( queue : RawCommandQueue ) ( matrix : ClMatrix < 'a >) ( vector : ClVector < 'b >) ( mask : ClVector < 'd >) ->
395+ match matrix, vector, mask with
396+ | ClMatrix.CSR m, ClVector.Sparse v, ClVector.Dense mask -> Option.map ClVector.Sparse ( run queue m v mask)
397+ | _ -> failwith " Not implemented yet"
398+
355399 /// <summary >
356400 /// CSR Matrix - sparse vector multiplication.
357401 /// </summary >
0 commit comments