Skip to content

Commit fcbeee4

Browse files
committed
add: ClArray.upperBoundWithValue
1 parent 970ffa0 commit fcbeee4

12 files changed

Lines changed: 272 additions & 189 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -268,9 +268,11 @@ module ClArray =
268268

269269
result
270270

271-
let firstOccurrence2 clContext = getUniqueBitmap2General firstOccurrence clContext
271+
let firstOccurrence2 clContext =
272+
getUniqueBitmap2General firstOccurrence clContext
272273

273-
let lastOccurrence2 clContext = getUniqueBitmap2General lastOccurrence clContext
274+
let lastOccurrence2 clContext =
275+
getUniqueBitmap2General lastOccurrence clContext
274276

275277
///<description>Remove duplicates form the given array.</description>
276278
///<param name="clContext">Computational context</param>
@@ -694,33 +696,39 @@ module ClArray =
694696
else
695697
None
696698

697-
let upperBound<'a when 'a : equality and 'a : comparison> (clContext: ClContext) workGroupSize =
699+
let private bound<'a, 'b when 'a: equality and 'a: comparison>
700+
(lowerBound: Expr<(int -> 'a -> ClArray<'a> -> 'b)>)
701+
(clContext: ClContext)
702+
workGroupSize
703+
=
698704

699705
let kernel =
700-
<@ fun (ndRange: Range1D) length (values: ClArray<'a>) (value: ClCell<'a>) (result: ClCell<int>) ->
706+
<@ fun (ndRange: Range1D) length (values: ClArray<'a>) (value: ClCell<'a>) (result: ClCell<'b>) ->
701707

702708
let value = value.Value
703709
let gid = ndRange.GlobalID0
704710

705711
if gid = 0 then
706712

707-
result.Value <-
708-
(%Search.Bin.lowerBound 0) length value values @>
713+
result.Value <- (%lowerBound) length value values @>
709714

710715
let program = clContext.Compile(kernel)
711716

712717
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (value: ClCell<'a>) ->
713-
let result = clContext.CreateClCell 0
718+
let result =
719+
clContext.CreateClCell Unchecked.defaultof<'b>
714720

715721
let kernel = program.GetKernel()
716722

717-
let ndRange =
718-
Range1D.CreateValid(1, workGroupSize)
723+
let ndRange = Range1D.CreateValid(1, workGroupSize)
719724

720725
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange values.Length values value result))
721726
processor.Post(Msg.CreateRunMsg<_, _> kernel)
722727

723728
result
724729

730+
let upperBoundAndValue<'a when 'a: comparison> clContext =
731+
bound<'a, int * 'a> Search.Bin.lowerBoundAndValue clContext
725732

726-
733+
let upperBound<'a when 'a: comparison> clContext =
734+
bound<'a, int> Search.Bin.lowerBound clContext

src/GraphBLAS-sharp.Backend/Matrix/CSR/Matrix.fs

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ module Matrix =
2222

2323
if gid < columnsLength then
2424
let result =
25-
(%Search.Bin.lowerBound 0) pointersLength gid pointers
25+
(%Search.Bin.lowerBound) pointersLength gid pointers
2626

27-
results.[gid] <- result - 1 @>
27+
results.[gid] <- result - 1 @>
2828

2929
let program = clContext.Compile kernel
3030

@@ -38,14 +38,16 @@ module Matrix =
3838
let ndRange =
3939
Range1D.CreateValid(matrix.Columns.Length, workGroupSize)
4040

41-
processor.Post(Msg.MsgSetArguments(
42-
fun () ->
43-
kernel.KernelFunc
44-
ndRange
45-
matrix.Columns.Length
46-
matrix.RowPointers.Length
47-
matrix.RowPointers
48-
rows))
41+
processor.Post(
42+
Msg.MsgSetArguments
43+
(fun () ->
44+
kernel.KernelFunc
45+
ndRange
46+
matrix.Columns.Length
47+
matrix.RowPointers.Length
48+
matrix.RowPointers
49+
rows)
50+
)
4951

5052
processor.Post(Msg.CreateRunMsg<_, _> kernel)
5153

@@ -63,9 +65,9 @@ module Matrix =
6365

6466
if gid < resultLength then
6567
let result =
66-
(%Search.Bin.lowerBound 0) pointersLength shiftedId pointers
68+
(%Search.Bin.lowerBound) pointersLength shiftedId pointers
6769

68-
results.[gid] <- result - 1 @>
70+
results.[gid] <- result - 1 @>
6971

7072
let program = clContext.Compile kernel
7173

@@ -86,7 +88,9 @@ module Matrix =
8688
// extract rows
8789
let rowPointers = matrix.RowPointers.ToHost processor
8890

89-
let resultLength = rowPointers.[startIndex + count] - rowPointers.[startIndex]
91+
let resultLength =
92+
rowPointers.[startIndex + count]
93+
- rowPointers.[startIndex]
9094

9195
let rows =
9296
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
@@ -96,15 +100,17 @@ module Matrix =
96100
let ndRange =
97101
Range1D.CreateValid(matrix.Columns.Length, workGroupSize)
98102

99-
processor.Post(Msg.MsgSetArguments(
100-
fun () ->
101-
kernel.KernelFunc
102-
ndRange
103-
resultLength
104-
startIndex
105-
matrix.RowPointers.Length
106-
matrix.RowPointers
107-
rows))
103+
processor.Post(
104+
Msg.MsgSetArguments
105+
(fun () ->
106+
kernel.KernelFunc
107+
ndRange
108+
resultLength
109+
startIndex
110+
matrix.RowPointers.Length
111+
matrix.RowPointers
112+
rows)
113+
)
108114

109115
processor.Post(Msg.CreateRunMsg<_, _> kernel)
110116

@@ -130,15 +136,15 @@ module Matrix =
130136
Values = values }
131137

132138
let toCOO (clContext: ClContext) workGroupSize =
133-
let prepare = expandRowPointers clContext workGroupSize
139+
let prepare =
140+
expandRowPointers clContext workGroupSize
134141

135142
let copy = ClArray.copy clContext workGroupSize
136143

137144
let copyData = ClArray.copy clContext workGroupSize
138145

139146
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
140-
let rows =
141-
prepare processor allocationMode matrix
147+
let rows = prepare processor allocationMode matrix
142148

143149
let cols =
144150
copy processor allocationMode matrix.Columns
@@ -154,11 +160,11 @@ module Matrix =
154160
Values = values }
155161

156162
let toCOOInPlace (clContext: ClContext) workGroupSize =
157-
let prepare = expandRowPointers clContext workGroupSize
163+
let prepare =
164+
expandRowPointers clContext workGroupSize
158165

159166
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
160-
let rows =
161-
prepare processor allocationMode matrix
167+
let rows = prepare processor allocationMode matrix
162168

163169
processor.Post(Msg.CreateFreeMsg(matrix.RowPointers))
164170

0 commit comments

Comments
 (0)