|
| 1 | +namespace GraphBLAS.FSharp.Backend.Vector |
| 2 | + |
| 3 | +open Brahma.FSharp |
| 4 | +open GraphBLAS.FSharp.Backend.Common |
| 5 | +open GraphBLAS.FSharp.Backend.Quotes |
| 6 | +open GraphBLAS.FSharp.Backend.Vector.Sparse |
| 7 | +open Microsoft.FSharp.Quotations |
| 8 | +open GraphBLAS.FSharp.Backend.Objects |
| 9 | +open GraphBLAS.FSharp.Backend.Objects.ClVector |
| 10 | +open GraphBLAS.FSharp.Backend.Objects.ClContext |
| 11 | +open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions |
| 12 | +open GraphBLAS.FSharp.Backend.Objects.ClCell |
| 13 | + |
| 14 | +module SpMSpV = |
| 15 | + |
// For every index v stored in vectorIndices, copy rowOffsets.[v] and rowOffsets.[v + 1]
// into consecutive cells of a newly allocated device array.
//
// The returned function takes the command queue, the number of vector indices (size),
// the vector indices and the row-offset array, and returns an array of
// size * 2 + 1 ints: two cells per vector index plus one trailing cell set to 0.
let private collectRows (clContext: ClContext) workGroupSize =

    let kernelBody =
        <@ fun (ndRange: Range1D) inputSize (vectorIndices: ClArray<int>) (rowOffsets: ClArray<int>) (resultArray: ClArray<int>) ->

            let gid = ndRange.GlobalID0

            // resultArray holds two cells per vector index, followed by one extra cell
            let pairedLength = inputSize * 2

            if gid = pairedLength then
                // trailing cell, initialised to zero
                resultArray.[gid] <- 0
            elif gid < pairedLength then
                let sourceIndex = vectorIndices.[gid / 2]

                // even cells receive rowOffsets.[v], odd cells receive rowOffsets.[v + 1]
                resultArray.[gid] <- rowOffsets.[sourceIndex + gid % 2] @>

    let program = clContext.Compile kernelBody

    fun (queue: MailboxProcessor<_>) size (vectorIndices: ClArray<int>) (rowOffsets: ClArray<int>) ->

        // one work item per result cell, including the trailing one
        let resultLength = size * 2 + 1

        // Last element will contain length of array for gather
        let resultRows =
            clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)

        let ndRange =
            Range1D.CreateValid(resultLength, workGroupSize)

        let kernel = program.GetKernel()

        queue.Post(
            Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange size vectorIndices rowOffsets resultRows)
        )

        queue.Post(Msg.CreateRunMsg<_, _>(kernel))

        resultRows
| 55 | + |
// Turns the array produced by collectRows into offsets, in place.
//
// Step 1 (kernel): for every even index i, replace the pair (a.[i], a.[i + 1]) with
//                  (0, a.[i + 1] - a.[i]), i.e. the pair's difference preceded by a zero.
// Step 2: run an in-place exclusive prefix sum over the whole array.
//
// The returned function takes the command queue, the number of cells to process (size)
// and the array; it returns the value produced by the prefix sum, copied to the host.
let private computeOffsetsInplace (clContext: ClContext) workGroupSize =

    let prepareOffsets =
        <@ fun (ndRange: Range1D) inputSize (inputArray: ClArray<int>) ->

            let i = ndRange.GlobalID0

            if i < inputSize && i % 2 = 0 then
                // When inputSize is odd the last even cell has no partner at i + 1.
                // Touching inputArray.[i + 1] there would read and write past the end of
                // the buffer, so only complete pairs are differenced.
                if i + 1 < inputSize then
                    inputArray.[i + 1] <- inputArray.[i + 1] - inputArray.[i]

                inputArray.[i] <- 0 @>

    let sum =
        PrefixSum.standardExcludeInPlace clContext workGroupSize

    let prepareOffsets = clContext.Compile prepareOffsets

    fun (queue: MailboxProcessor<_>) size (input: ClArray<int>) ->

        let ndRange = Range1D.CreateValid(size, workGroupSize)

        let prepareOffsets = prepareOffsets.GetKernel()

        queue.Post(Msg.MsgSetArguments(fun () -> prepareOffsets.KernelFunc ndRange size input))

        queue.Post(Msg.CreateRunMsg<_, _>(prepareOffsets))

        // The prefix sum returns a device cell; copy its value to the host and release it.
        let resSize = (sum queue input).ToHostAndFree queue

        resSize
| 86 | + |
// Gather rows from the matrix that will be used in multiplication.
//
// For each index stored in the sparse vector, the corresponding CSR row of the matrix is
// copied into three flat arrays (row number, column index, value).
//
// The returned function takes the command queue, a CSR matrix and a sparse vector, and
// returns (rows, columns, values, length). When nothing is gathered, length is 0 and the
// three arrays are one-element placeholders with unspecified contents.
let private gather (clContext: ClContext) workGroupSize =

    let gather =
        <@ fun (ndRange: Range1D) vectorNNZ (rowOffsets: ClArray<int>) (matrixRowPointers: ClArray<int>) (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (vectorIndices: ClArray<int>) (resultRowsArray: ClArray<int>) (resultIndicesArray: ClArray<int>) (resultValuesArray: ClArray<'a>) ->

            // Serial number of row to gather
            let row = ndRange.GlobalID0

            if row < vectorNNZ then
                // Cell 2 * row + 1 holds the first output position of this row,
                // the following cell holds the position just past its last element.
                let offsetIndex = row * 2 + 1
                let rowOffset = rowOffsets.[offsetIndex]
                let nextRowOffset = rowOffsets.[offsetIndex + 1]

                // vectorIndices.[row] --- actual number of row in matrix
                let actualRow = vectorIndices.[row]
                let matrixIndexOffset = matrixRowPointers.[actualRow]

                if rowOffset <> nextRowOffset then
                    let rowSize = nextRowOffset - rowOffset

                    for i in 0 .. rowSize - 1 do
                        resultRowsArray.[i + rowOffset] <- actualRow
                        resultIndicesArray.[i + rowOffset] <- matrixColumns.[matrixIndexOffset + i]
                        resultValuesArray.[i + rowOffset] <- matrixValues.[matrixIndexOffset + i] @>

    let collectRows = collectRows clContext workGroupSize

    let computeOffsetsInplace =
        computeOffsetsInplace clContext workGroupSize

    let gather = clContext.Compile gather

    fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->

        // Collect R[v] and R[v + 1] for each v in vector
        let collectedRows =
            collectRows queue vector.NNZ vector.Indices matrix.RowPointers

        // Place R[v + 1] - R[v] in previous array and do prefix sum, computing offsets for gather array
        let gatherArraySize =
            computeOffsetsInplace queue (vector.NNZ * 2 + 1) collectedRows

        if gatherArraySize = 0 then
            // Nothing to gather: the offsets buffer is no longer needed, release it here
            // so that this early exit does not leak device memory.
            collectedRows.Free queue

            // One-element placeholders stand in for the empty result.
            let resultRows =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, 1)

            let resultValues =
                clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, 1)

            let resultColumns =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, 1)

            resultRows, resultColumns, resultValues, gatherArraySize
        else
            let ndRange =
                Range1D.CreateValid(vector.NNZ, workGroupSize)

            let gather = gather.GetKernel()

            let resultRows =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, gatherArraySize)

            let resultIndices =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, gatherArraySize)

            let resultValues =
                clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, gatherArraySize)

            queue.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        gather.KernelFunc
                            ndRange
                            vector.NNZ
                            collectedRows
                            matrix.RowPointers
                            matrix.Columns
                            matrix.Values
                            vector.Indices
                            resultRows
                            resultIndices
                            resultValues)
            )

            queue.Post(Msg.CreateRunMsg<_, _>(gather))

            // Posted after the kernel run message, so the release is queued behind the gather.
            collectedRows.Free queue

            resultRows, resultIndices, resultValues, gatherArraySize
| 177 | + |
| 178 | + |
// Element-wise multiplication of gathered matrix values by the matching vector values.
//
// For every position i below the result length, the key stored at position i is looked up
// in the sparse vector with a binary search, and `mul` is applied to
// (Some matrixValue, lookup result). The outcome is stored at position i of a newly
// allocated array, which the returned function hands back to the caller.
//
// NOTE(review): the kernel calls its key array `rowIndices` while the host-side parameter
// is called `columnIndices` --- confirm with callers which of the gathered arrays is intended.
let private multiplyScalar (clContext: ClContext) (mul: Expr<'a option -> 'b option -> 'c option>) workGroupSize =

    let kernelBody =
        <@ fun (ndRange: Range1D) resultLength vectorLength (rowIndices: ClArray<int>) (matrixValues: ClArray<'a>) (vectorIndices: ClArray<int>) (vectorValues: ClArray<'b>) (resultValues: ClArray<'c option>) ->
            let gid = ndRange.GlobalID0

            if gid < resultLength then
                let key = rowIndices.[gid]
                let left = matrixValues.[gid]

                // lookup result for `key` among the vector's indices, passed to `mul` as is
                let right =
                    (%Search.Bin.byKey) vectorLength key vectorIndices vectorValues

                resultValues.[gid] <- (%mul) (Some left) right @>

    let program = clContext.Compile kernelBody

    fun (queue: MailboxProcessor<_>) (columnIndices: ClArray<int>) (matrixValues: ClArray<'a>) (vector: Sparse<'b>) ->

        // one output cell per gathered element
        let length = columnIndices.Length

        let ndRange = Range1D.CreateValid(length, workGroupSize)

        let products =
            clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)

        let kernel = program.GetKernel()

        queue.Post(
            Msg.MsgSetArguments
                (fun () ->
                    kernel.KernelFunc
                        ndRange
                        length
                        vector.NNZ
                        columnIndices
                        matrixValues
                        vector.Indices
                        vector.Values
                        products)
        )

        queue.Post(Msg.CreateRunMsg<_, _>(kernel))

        products
| 226 | + |
// Sparse matrix by sparse vector multiplication producing a boolean sparse vector.
//
// The matrix rows selected by the vector's indices are gathered; the gathered column
// indices are sorted and de-duplicated and become the indices of the result, and every
// result value is `true`. When nothing is gathered, the result carries the one-element
// placeholder index array returned by the gather step and a single `false` value.
//
// `add` and `mul` are accepted as part of the signature but are not used by this body.
let runBoolStandard
    (add: Expr<'c option -> 'c option -> 'c option>)
    (mul: Expr<'a option -> 'b option -> 'c option>)
    (clContext: ClContext)
    workGroupSize
    =

    let gatherMatrixRows = gather clContext workGroupSize

    let sortKeys =
        Sort.Radix.standardRunKeysOnly clContext workGroupSize

    let dropDuplicates =
        ClArray.removeDuplications clContext workGroupSize

    let fill = ClArray.create clContext workGroupSize

    fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->

        let rows, columns, values, gatheredCount = gatherMatrixRows queue matrix vector

        // only the column indices are needed for the boolean result
        rows.Free queue
        values.Free queue

        if gatheredCount > 0 then
            let sortedColumns = sortKeys queue columns

            let uniqueColumns = dropDuplicates queue sortedColumns

            columns.Free queue
            sortedColumns.Free queue

            // every surviving index is present in the result, hence `true`
            let resultValues =
                fill queue DeviceOnly uniqueColumns.Length true

            { Context = clContext
              Indices = uniqueColumns
              Values = resultValues
              Size = matrix.ColumnCount }
        else
            { Context = clContext
              Indices = columns
              Values = clContext.CreateClArray [| false |]
              Size = matrix.ColumnCount }
0 commit comments