Skip to content

Commit 1450728

Browse files
committed
add: ClArray.map
1 parent 7e331a5 commit 1450728

3 files changed

Lines changed: 99 additions & 37 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ module ClArray =
2020
fun (processor: MailboxProcessor<_>) (length: int) ->
2121
// TODO: Выставить нужные флаги
2222
let outputArray =
23-
clContext.CreateClArrayWithGPUOnlyFlags(length)
23+
clContext.CreateClArrayWithFlag(GPUOnly, length)
2424

2525
let kernel = program.GetKernel()
2626

@@ -48,7 +48,7 @@ module ClArray =
4848
let value = clContext.CreateClCell(value)
4949

5050
let outputArray =
51-
clContext.CreateClArrayWithGPUOnlyFlags(length)
51+
clContext.CreateClArrayWithFlag(GPUOnly, length)
5252

5353
let kernel = program.GetKernel()
5454

@@ -84,7 +84,7 @@ module ClArray =
8484
Range1D.CreateValid(inputArray.Length, workGroupSize)
8585

8686
let outputArray =
87-
clContext.CreateClArrayWithGPUOnlyFlags inputArray.Length
87+
clContext.CreateClArrayWithFlag(GPUOnly, inputArray.Length)
8888

8989
let kernel = program.GetKernel()
9090

@@ -112,7 +112,7 @@ module ClArray =
112112
let outputArrayLength = inputArray.Length * count
113113

114114
let outputArray =
115-
clContext.CreateClArrayWithGPUOnlyFlags outputArrayLength
115+
clContext.CreateClArrayWithFlag(GPUOnly, outputArrayLength)
116116

117117
let ndRange =
118118
Range1D.CreateValid(outputArray.Length, workGroupSize)
@@ -229,7 +229,7 @@ module ClArray =
229229
Range1D.CreateValid(inputLength, workGroupSize)
230230

231231
let bitmap =
232-
clContext.CreateClArrayWithGPUOnlyFlags inputLength
232+
clContext.CreateClArrayWithFlag(GPUOnly, inputLength)
233233

234234
let kernel = kernel.GetKernel()
235235

@@ -273,7 +273,7 @@ module ClArray =
273273
a.[0]
274274

275275
let outputArray =
276-
clContext.CreateClArrayWithGPUOnlyFlags resultLength
276+
clContext.CreateClArrayWithFlag(GPUOnly, resultLength)
277277

278278
scatter processor positions inputArray outputArray
279279

@@ -307,3 +307,31 @@ module ClArray =
307307
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
308308

309309
result
310+
311+
let map<'a, 'b> (clContext: ClContext) (workGroupSize: int) (op: Expr<'a -> 'b>) =
312+
313+
let map =
314+
<@ fun (ndRange: Range1D) (lenght: int) (inputArray: ClArray<'a>) (result: ClArray<'b>) ->
315+
316+
let gid = ndRange.GlobalID0
317+
318+
if gid < lenght then
319+
result.[gid] <- (%op) inputArray.[gid] @>
320+
321+
let kernel = clContext.Compile map
322+
323+
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) flag ->
324+
325+
let result =
326+
clContext.CreateClArrayWithFlag(flag, inputArray.Length)
327+
328+
let ndRange =
329+
Range1D.CreateValid(workGroupSize, inputArray.Length)
330+
331+
let kernel = kernel.GetKernel()
332+
333+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange inputArray.Length inputArray result))
334+
335+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
336+
337+
result

src/GraphBLAS-sharp.Backend/Objects/ClContextExtensions.fs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,41 @@ namespace GraphBLAS.FSharp.Backend.Objects
33
open Brahma.FSharp
44

55
module ClContext =
6+
type AllocationFlag =
7+
| GPUOnly
8+
| CPUInterop
9+
610
type ClContext with
7-
member this.CreateClArrayWithGPUOnlyFlags(size: int) =
8-
this.CreateClArray(
9-
size,
10-
deviceAccessMode = DeviceAccessMode.ReadWrite,
11-
hostAccessMode = HostAccessMode.NotAccessible,
12-
allocationMode = AllocationMode.Default
13-
)
11+
member this.CreateClArrayWithFlag(mode, (size: int)) =
12+
match mode with
13+
| GPUOnly ->
14+
this.CreateClArray(
15+
size,
16+
deviceAccessMode = DeviceAccessMode.ReadWrite,
17+
hostAccessMode = HostAccessMode.NotAccessible,
18+
allocationMode = AllocationMode.Default
19+
)
20+
| CPUInterop ->
21+
this.CreateClArray(
22+
size,
23+
deviceAccessMode = DeviceAccessMode.ReadWrite,
24+
hostAccessMode = HostAccessMode.ReadWrite,
25+
allocationMode = AllocationMode.Default
26+
)
1427

15-
member this.CreateClArrayWithGPUOnlyFlags(array: 'a []) =
16-
this.CreateClArray(
17-
array,
18-
deviceAccessMode = DeviceAccessMode.ReadWrite,
19-
hostAccessMode = HostAccessMode.NotAccessible,
20-
allocationMode = AllocationMode.CopyHostPtr
21-
)
28+
member this.CreateClArrayWithFlag(mode, (array: 'a [])) =
29+
match mode with
30+
| GPUOnly ->
31+
this.CreateClArray(
32+
array,
33+
deviceAccessMode = DeviceAccessMode.ReadWrite,
34+
hostAccessMode = HostAccessMode.NotAccessible,
35+
allocationMode = AllocationMode.CopyHostPtr
36+
)
37+
| CPUInterop ->
38+
this.CreateClArray(
39+
array,
40+
deviceAccessMode = DeviceAccessMode.ReadWrite,
41+
hostAccessMode = HostAccessMode.ReadWrite,
42+
allocationMode = AllocationMode.CopyHostPtr
43+
)

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,45 +17,57 @@ module Vector =
1717
let zeroCreate =
1818
ClArray.zeroCreate clContext workGroupSize
1919

20-
fun (processor: MailboxProcessor<_>) (size: int) (format: VectorFormat) ->
20+
fun (processor: MailboxProcessor<_>) (size: int) (format: VectorFormat) (flag: AllocationFlag) ->
2121
match format with
2222
| Sparse ->
23-
let vector =
23+
ClVector.Sparse
2424
{ Context = clContext
25-
Indices = clContext.CreateClArrayWithGPUOnlyFlags [| 0 |]
26-
Values = clContext.CreateClArrayWithGPUOnlyFlags [| Unchecked.defaultof<'a> |]
25+
Indices = clContext.CreateClArrayWithFlag(flag, [| 0 |])
26+
Values = clContext.CreateClArrayWithFlag(flag, [| Unchecked.defaultof<'a> |])
2727
Size = size }
28-
29-
ClVector.Sparse vector
3028
| Dense -> ClVector.Dense <| zeroCreate processor size
3129

32-
let ofList (clContext: ClContext) =
33-
fun (format: VectorFormat) size (elements: (int * 'a) list) ->
30+
let ofList (clContext: ClContext) (workGroupSize: int) =
31+
32+
let scatter =
33+
Scatter.runInplace clContext workGroupSize
34+
35+
let map =
36+
ClArray.map clContext workGroupSize <@ Some @>
37+
38+
fun (processor: MailboxProcessor<_>) (format: VectorFormat) flag size (elements: (int * 'a) list) ->
3439
let indices, values =
3540
elements
3641
|> Array.ofList
3742
|> Array.sortBy fst
3843
|> Array.unzip
3944

45+
let indices =
46+
clContext.CreateClArrayWithFlag(flag, indices)
47+
48+
let values =
49+
clContext.CreateClArrayWithFlag(flag, values)
50+
4051
match format with
4152
| Sparse ->
42-
let indices = clContext.CreateClArray indices
43-
let values = clContext.CreateClArray values
44-
4553
{ Context = clContext
4654
Indices = indices
4755
Values = values
4856
Size = size }
49-
5057
|> ClVector.Sparse
5158
| Dense ->
52-
let res = Array.zeroCreate size
59+
let mappedValues = map processor values flag
5360

54-
for i in 0 .. indices.Length - 1 do
55-
res.[indices.[i]] <- Some(values.[i])
61+
let result =
62+
clContext.CreateClArrayWithFlag(flag, size)
5663

57-
ClVector.Dense
58-
<| clContext.CreateClArrayWithGPUOnlyFlags res
64+
scatter processor indices mappedValues result
65+
66+
processor.Post(Msg.CreateFreeMsg(mappedValues))
67+
processor.Post(Msg.CreateFreeMsg(indices))
68+
processor.Post(Msg.CreateFreeMsg(values))
69+
70+
ClVector.Dense result
5971

6072
let copy (clContext: ClContext) (workGroupSize: int) =
6173
let copy = ClArray.copy clContext workGroupSize

0 commit comments

Comments
 (0)