Skip to content

Commit cd2df91

Browse files
committed
refactor: flags in ClArray
1 parent ab7d3c3 commit cd2df91

2 files changed

Lines changed: 23 additions & 24 deletions

File tree

src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module BFS =
2323
SpMV.runTo clContext add mul workGroupSize
2424

2525
let zeroCreate =
26-
ClArray.zeroCreate clContext workGroupSize
26+
ClArray.zeroCreate clContext workGroupSize CPUInterop
2727

2828
let ofList = Vector.ofList clContext workGroupSize
2929

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ open Microsoft.FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Objects.ClContext
66

77
module ClArray =
8-
let init (initializer: Expr<int -> 'a>) (clContext: ClContext) workGroupSize =
8+
let init (clContext: ClContext) workGroupSize flag (initializer: Expr<int -> 'a>) =
99

1010
let init =
1111
<@ fun (range: Range1D) (outputBuffer: ClArray<'a>) (length: int) ->
@@ -18,9 +18,8 @@ module ClArray =
1818
let program = clContext.Compile(init)
1919

2020
fun (processor: MailboxProcessor<_>) (length: int) ->
21-
// TODO: Выставить нужные флаги
2221
let outputArray =
23-
clContext.CreateClArrayWithFlag(GPUOnly, length)
22+
clContext.CreateClArrayWithFlag(flag, length)
2423

2524
let kernel = program.GetKernel()
2625

@@ -32,7 +31,7 @@ module ClArray =
3231

3332
outputArray
3433

35-
let create (clContext: ClContext) workGroupSize =
34+
let create (clContext: ClContext) workGroupSize flag =
3635

3736
let create =
3837
<@ fun (range: Range1D) (outputBuffer: ClArray<'a>) (length: int) (value: ClCell<'a>) ->
@@ -48,7 +47,7 @@ module ClArray =
4847
let value = clContext.CreateClCell(value)
4948

5049
let outputArray =
51-
clContext.CreateClArrayWithFlag(GPUOnly, length)
50+
clContext.CreateClArrayWithFlag(flag, length)
5251

5352
let kernel = program.GetKernel()
5453

@@ -61,15 +60,15 @@ module ClArray =
6160

6261
outputArray
6362

64-
let zeroCreate (clContext: ClContext) workGroupSize =
63+
let zeroCreate (clContext: ClContext) workGroupSize flag =
6564

66-
let create = create clContext workGroupSize
65+
let create = create clContext workGroupSize flag
6766

68-
fun (processor: MailboxProcessor<_>) (length: int) -> create processor length Unchecked.defaultof<'a>
67+
fun (processor: MailboxProcessor<_>) length ->
68+
create processor length Unchecked.defaultof<'a>
6969

70-
let copy (clContext: ClContext) workGroupSize =
70+
let copy (clContext: ClContext) workGroupSize flag =
7171
let copy =
72-
7372
<@ fun (ndRange: Range1D) (inputArrayBuffer: ClArray<'a>) (outputArrayBuffer: ClArray<'a>) inputArrayLength ->
7473

7574
let i = ndRange.GlobalID0
@@ -84,7 +83,7 @@ module ClArray =
8483
Range1D.CreateValid(inputArray.Length, workGroupSize)
8584

8685
let outputArray =
87-
clContext.CreateClArrayWithFlag(GPUOnly, inputArray.Length)
86+
clContext.CreateClArrayWithFlag(flag, inputArray.Length)
8887

8988
let kernel = program.GetKernel()
9089

@@ -96,7 +95,7 @@ module ClArray =
9695

9796
outputArray
9897

99-
let replicate (clContext: ClContext) =
98+
let replicate (clContext: ClContext) flag =
10099

101100
let replicate =
102101
<@ fun (ndRange: Range1D) (inputArrayBuffer: ClArray<'a>) (outputArrayBuffer: ClArray<'a>) inputArrayLength outputArrayLength ->
@@ -112,7 +111,7 @@ module ClArray =
112111
let outputArrayLength = inputArray.Length * count
113112

114113
let outputArray =
115-
clContext.CreateClArrayWithFlag(GPUOnly, outputArrayLength)
114+
clContext.CreateClArrayWithFlag(flag, outputArrayLength)
116115

117116
let ndRange =
118117
Range1D.CreateValid(outputArray.Length, workGroupSize)
@@ -174,25 +173,25 @@ module ClArray =
174173
///<param name="zero">Zero element for binary operation.</param>
175174
let prefixSumIncludeInplace = PrefixSum.runIncludeInplace
176175

177-
let prefixSumExclude plus (clContext: ClContext) workGroupSize =
176+
let prefixSumExclude plus (clContext: ClContext) workGroupSize flag =
178177

179178
let runExcludeInplace =
180179
prefixSumExcludeInplace plus clContext workGroupSize
181180

182-
let copy = copy clContext workGroupSize
181+
let copy = copy clContext workGroupSize flag
183182

184183
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (totalSum: ClCell<'a>) (zero: 'a) ->
185184

186185
let outputArray = copy processor inputArray
187186

188187
runExcludeInplace processor outputArray totalSum zero
189188

190-
let prefixSumInclude plus (clContext: ClContext) workGroupSize =
189+
let prefixSumInclude plus (clContext: ClContext) workGroupSize flag =
191190

192191
let runIncludeInplace =
193192
prefixSumIncludeInplace plus clContext workGroupSize
194193

195-
let copy = copy clContext workGroupSize
194+
let copy = copy clContext workGroupSize flag
196195

197196
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (totalSum: ClCell<'a>) (zero: 'a) ->
198197

@@ -206,7 +205,7 @@ module ClArray =
206205
let prefixSumBackwardsIncludeInplace plus =
207206
PrefixSum.runBackwardsIncludeInplace plus
208207

209-
let getUniqueBitmap (clContext: ClContext) =
208+
let getUniqueBitmap (clContext: ClContext) flag =
210209

211210
let getUniqueBitmap =
212211
<@ fun (ndRange: Range1D) (inputArray: ClArray<'a>) inputLength (isUniqueBitmap: ClArray<int>) ->
@@ -229,7 +228,7 @@ module ClArray =
229228
Range1D.CreateValid(inputLength, workGroupSize)
230229

231230
let bitmap =
232-
clContext.CreateClArrayWithFlag(GPUOnly, inputLength)
231+
clContext.CreateClArrayWithFlag(flag, inputLength)
233232

234233
let kernel = kernel.GetKernel()
235234

@@ -248,10 +247,10 @@ module ClArray =
248247
let scatter =
249248
Scatter.runInplace clContext workGroupSize
250249

251-
let getUniqueBitmap = getUniqueBitmap clContext
250+
let getUniqueBitmap = getUniqueBitmap clContext GPUOnly
252251

253252
let prefixSumExclude =
254-
prefixSumExclude <@ (+) @> clContext workGroupSize
253+
prefixSumExclude <@ (+) @> clContext workGroupSize GPUOnly
255254

256255
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->
257256

@@ -308,7 +307,7 @@ module ClArray =
308307

309308
result
310309

311-
let map<'a, 'b> (clContext: ClContext) (workGroupSize: int) (op: Expr<'a -> 'b>) =
310+
let map<'a, 'b> (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b>) flag =
312311

313312
let map =
314313
<@ fun (ndRange: Range1D) (lenght: int) (inputArray: ClArray<'a>) (result: ClArray<'b>) ->
@@ -320,7 +319,7 @@ module ClArray =
320319

321320
let kernel = clContext.Compile map
322321

323-
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) flag ->
322+
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->
324323

325324
let result =
326325
clContext.CreateClArrayWithFlag(flag, inputArray.Length)

0 commit comments

Comments
 (0)