Skip to content

Commit 6cab3e6

Browse files
committed
eWiseAdd with AtLeastOne
1 parent 22bd23d commit 6cab3e6

5 files changed

Lines changed: 294 additions & 1 deletion

File tree

src/GraphBLAS-sharp.Backend/Common/StandardOperations.fs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
namespace GraphBLAS.FSharp.Backend.Common
22

3+
open GraphBLAS.FSharp.Backend.Common
4+
5+
type AtLeastOne<'a, 'b when 'a: struct and 'b: struct> =
6+
| Both of 'a * 'b
7+
| Left of 'a
8+
| Right of 'b
9+
310
module StandardOperations =
411
let boolSum =
512
<@ fun (_: bool option) (_: bool option) -> Some true @>
@@ -51,3 +58,50 @@ module StandardOperations =
5158
| None, None -> ()
5259

5360
if res = 0f then None else (Some res) @>
61+
62+
let boolSum2 =
63+
<@ fun (_: AtLeastOne<bool, bool>) -> Some true @>
64+
65+
let intSum2 =
66+
<@ fun (values: AtLeastOne<int, int>) ->
67+
let mutable res = 0
68+
69+
match values with
70+
| Both (f, s) -> res <- f + s
71+
| Left f -> res <- f
72+
| Right s -> res <- s
73+
74+
if res = 0 then None else (Some res) @>
75+
76+
let byteSum2 =
77+
<@ fun (values: AtLeastOne<byte, byte>) ->
78+
let mutable res = 0uy
79+
80+
match values with
81+
| Both (f, s) -> res <- f + s
82+
| Left f -> res <- f
83+
| Right s -> res <- s
84+
85+
if res = 0uy then None else (Some res) @>
86+
87+
let floatSum2 =
88+
<@ fun (values: AtLeastOne<float, float>) ->
89+
let mutable res = 0.0
90+
91+
match values with
92+
| Both (f, s) -> res <- f + s
93+
| Left f -> res <- f
94+
| Right s -> res <- s
95+
96+
if res = 0.0 then None else (Some res) @>
97+
98+
let float32Sum2 =
99+
<@ fun (values: AtLeastOne<float32, float32>) ->
100+
let mutable res = 0f
101+
102+
match values with
103+
| Both (f, s) -> res <- f + s
104+
| Left f -> res <- f
105+
| Right s -> res <- s
106+
107+
if res = 0f then None else (Some res) @>

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/COOMatrix.fs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
namespace GraphBLAS.FSharp.Backend
22

33
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
5+
open GraphBLAS.FSharp.Backend.Common
46
open Microsoft.FSharp.Quotations
57

68
module COOMatrix =
@@ -659,3 +661,135 @@ module COOMatrix =
659661
RowPointers = compressedRows
660662
Columns = matrix.Columns
661663
Values = matrix.Values }
664+
665+
let private preparePositions2<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
666+
(clContext: ClContext)
667+
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
668+
workGroupSize
669+
=
670+
671+
let preparePositions =
672+
<@ fun (ndRange: Range1D) length (allRowsBuffer: ClArray<int>) (allColumnsBuffer: ClArray<int>) (leftValuesBuffer: ClArray<'a>) (rightValuesBuffer: ClArray<'b>) (allValuesBuffer: ClArray<'c>) (rawPositionsBuffer: ClArray<int>) (isLeftBitmap: ClArray<int>) ->
673+
674+
let i = ndRange.GlobalID0
675+
676+
if (i < length - 1
677+
&& allRowsBuffer.[i] = allRowsBuffer.[i + 1]
678+
&& allColumnsBuffer.[i] = allColumnsBuffer.[i + 1]) then
679+
rawPositionsBuffer.[i] <- 0
680+
681+
match (%opAdd) (Both(leftValuesBuffer.[i + 1], rightValuesBuffer.[i])) with
682+
| Some v ->
683+
allValuesBuffer.[i + 1] <- v
684+
rawPositionsBuffer.[i + 1] <- 1
685+
| None -> rawPositionsBuffer.[i + 1] <- 0
686+
else if (i > 0
687+
&& i < length
688+
&& (allRowsBuffer.[i] <> allRowsBuffer.[i - 1]
689+
|| allColumnsBuffer.[i] <> allColumnsBuffer.[i - 1]))
690+
|| i = 0 then
691+
if isLeftBitmap.[i] = 1 then
692+
match (%opAdd) (Left leftValuesBuffer.[i]) with
693+
| Some v ->
694+
allValuesBuffer.[i] <- v
695+
rawPositionsBuffer.[i] <- 1
696+
| None -> rawPositionsBuffer.[i] <- 0
697+
else
698+
match (%opAdd) (Right rightValuesBuffer.[i]) with
699+
| Some v ->
700+
allValuesBuffer.[i] <- v
701+
rawPositionsBuffer.[i] <- 1
702+
| None -> rawPositionsBuffer.[i] <- 0 @>
703+
704+
let kernel =
705+
clContext.Compile(preparePositions)
706+
707+
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->
708+
let length = leftValues.Length
709+
710+
let ndRange =
711+
Range1D.CreateValid(length, workGroupSize)
712+
713+
let rawPositionsGpu =
714+
clContext.CreateClArray<int>(
715+
length,
716+
hostAccessMode = HostAccessMode.NotAccessible,
717+
allocationMode = AllocationMode.Default
718+
)
719+
720+
let allValues =
721+
clContext.CreateClArray<'c>(
722+
length,
723+
hostAccessMode = HostAccessMode.NotAccessible,
724+
allocationMode = AllocationMode.Default
725+
)
726+
727+
let kernel = kernel.GetKernel()
728+
729+
processor.Post(
730+
Msg.MsgSetArguments
731+
(fun () ->
732+
kernel.KernelFunc
733+
ndRange
734+
length
735+
allRows
736+
allColumns
737+
leftValues
738+
rightValues
739+
allValues
740+
rawPositionsGpu
741+
isLeft)
742+
)
743+
744+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
745+
rawPositionsGpu, allValues
746+
747+
///<param name="clContext">.</param>
748+
///<param name="opAdd">.</param>
749+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
750+
let eWiseAdd2<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
751+
(clContext: ClContext)
752+
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
753+
workGroupSize
754+
=
755+
756+
let merge = merge clContext workGroupSize
757+
758+
let preparePositions =
759+
preparePositions2 clContext opAdd workGroupSize
760+
761+
let setPositions = setPositions<'c> clContext workGroupSize
762+
763+
fun (queue: MailboxProcessor<_>) (matrixLeft: COOMatrix<'a>) (matrixRight: COOMatrix<'b>) ->
764+
765+
let allRows, allColumns, leftMergedValues, rightMergedValues, isLeft =
766+
merge
767+
queue
768+
matrixLeft.Rows
769+
matrixLeft.Columns
770+
matrixLeft.Values
771+
matrixRight.Rows
772+
matrixRight.Columns
773+
matrixRight.Values
774+
775+
let rawPositions, allValues =
776+
preparePositions queue allRows allColumns leftMergedValues rightMergedValues isLeft
777+
778+
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
779+
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
780+
781+
let resultRows, resultColumns, resultValues, resultLength =
782+
setPositions queue allRows allColumns allValues rawPositions
783+
784+
queue.Post(Msg.CreateFreeMsg<_>(isLeft))
785+
queue.Post(Msg.CreateFreeMsg<_>(rawPositions))
786+
queue.Post(Msg.CreateFreeMsg<_>(allRows))
787+
queue.Post(Msg.CreateFreeMsg<_>(allColumns))
788+
queue.Post(Msg.CreateFreeMsg<_>(allValues))
789+
790+
{ Context = clContext
791+
RowCount = matrixLeft.RowCount
792+
ColumnCount = matrixLeft.ColumnCount
793+
Rows = resultRows
794+
Columns = resultColumns
795+
Values = resultValues }

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ namespace GraphBLAS.FSharp.Backend
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend
5+
open GraphBLAS.FSharp.Backend.Common
56
open Microsoft.FSharp.Quotations
67

78
module CSRMatrix =
@@ -117,3 +118,29 @@ module CSRMatrix =
117118
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
118119

119120
m3
121+
122+
let eWiseAdd2 (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
123+
124+
let toCOOInplaceLeft = toCOOInplace clContext workGroupSize
125+
let toCOOInplaceRight = toCOOInplace clContext workGroupSize
126+
127+
let eWiseCOO =
128+
COOMatrix.eWiseAdd2 clContext opAdd workGroupSize
129+
130+
let toCSRInplace =
131+
COOMatrix.toCSRInplace clContext workGroupSize
132+
133+
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
134+
135+
let m1COO = toCOOInplaceLeft processor m1
136+
let m2COO = toCOOInplaceRight processor m2
137+
138+
let m3COO = eWiseCOO processor m1COO m2COO
139+
140+
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
141+
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
142+
143+
let m3 = toCSRInplace processor m3COO
144+
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
145+
146+
m3

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
namespace GraphBLAS.FSharp.Backend
22

33
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
45
open Microsoft.FSharp.Quotations
5-
open OpenCL.Net
6+
open GraphBLAS.FSharp.Backend.Common
67

78
module Matrix =
89
let copy (clContext: ClContext) =
@@ -65,3 +66,16 @@ module Matrix =
6566
| MatrixCOO m1, MatrixCOO m2 -> COOeWiseAdd processor m1 m2 |> MatrixCOO
6667
| MatrixCSR m1, MatrixCSR m2 -> CSReWiseAdd processor m1 m2 |> MatrixCSR
6768
| _ -> failwith "Matrix formats are not matching"
69+
70+
let eWiseAdd2 (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
71+
let COOeWiseAdd =
72+
COOMatrix.eWiseAdd2 clContext opAdd workGroupSize
73+
74+
let CSReWiseAdd =
75+
CSRMatrix.eWiseAdd2 clContext opAdd workGroupSize
76+
77+
fun (processor: MailboxProcessor<_>) matrix1 matrix2 ->
78+
match matrix1, matrix2 with
79+
| MatrixCOO m1, MatrixCOO m2 -> COOeWiseAdd processor m1 m2 |> MatrixCOO
80+
| MatrixCSR m1, MatrixCSR m2 -> CSReWiseAdd processor m1 m2 |> MatrixCSR
81+
| _ -> failwith "Matrix formats are not matching"

tests/GraphBLAS-sharp.Tests/BackendCommonTests/MatrixEwiseAddTests.fs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,67 @@ let tests =
147147
)
148148
|> List.collect testFixtures
149149
|> testList "Backend.Matrix.eWiseAdd tests"
150+
151+
let testFixtures2 case =
152+
[ let config = defaultConfig
153+
let wgSize = 256
154+
155+
let getCorrectnessTestName datatype =
156+
sprintf "Correctness on %s, %A" datatype case
157+
158+
let context = case.ClContext.ClContext
159+
let q = case.ClContext.Queue
160+
q.Error.Add(fun e -> failwithf "%A" e)
161+
162+
let boolAdd =
163+
Matrix.eWiseAdd2 context boolSum2 wgSize
164+
165+
let boolToCOO = Matrix.toCOO context wgSize
166+
167+
case
168+
|> correctnessGenericTest false (||) boolAdd boolToCOO (=) q
169+
|> testPropertyWithConfig config (getCorrectnessTestName "bool")
170+
171+
let intAdd =
172+
Matrix.eWiseAdd2 context intSum2 wgSize
173+
174+
let intToCOO = Matrix.toCOO context wgSize
175+
176+
case
177+
|> correctnessGenericTest 0 (+) intAdd intToCOO (=) q
178+
|> testPropertyWithConfig config (getCorrectnessTestName "int")
179+
180+
let floatAdd =
181+
Matrix.eWiseAdd2 context floatSum2 wgSize
182+
183+
let floatToCOO = Matrix.toCOO context wgSize
184+
185+
case
186+
|> correctnessGenericTest 0.0 (+) floatAdd floatToCOO (fun x y -> abs (x - y) < Accuracy.medium.absolute) q
187+
|> testPropertyWithConfig config (getCorrectnessTestName "float")
188+
189+
let byteAdd =
190+
Matrix.eWiseAdd2 context byteSum2 wgSize
191+
192+
let byteToCOO = Matrix.toCOO context wgSize
193+
194+
case
195+
|> correctnessGenericTest 0uy (+) byteAdd byteToCOO (=) q
196+
|> testPropertyWithConfig config (getCorrectnessTestName "byte") ]
197+
198+
let tests2 =
199+
testCases
200+
|> List.filter
201+
(fun case ->
202+
let mutable e = ErrorCode.Unknown
203+
let device = case.ClContext.ClContext.ClDevice.Device
204+
205+
let deviceType =
206+
Cl
207+
.GetDeviceInfo(device, DeviceInfo.Type, &e)
208+
.CastTo<DeviceType>()
209+
210+
deviceType = DeviceType.Gpu
211+
)
212+
|> List.collect testFixtures2
213+
|> testList "Backend.Matrix.eWiseAdd2 tests"

0 commit comments

Comments
 (0)