Skip to content

Commit 830b2f9

Browse files
committed
wip: ClArray.pairwise
1 parent 1f273d1 commit 830b2f9

6 files changed

Lines changed: 125 additions & 55 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -621,35 +621,36 @@ module ClArray =
621621
let fill (clContext: ClContext) workGroupSize =
622622

623623
let fill =
624-
<@ fun (ndRange: Range1D) firstPosition endPosition (value: ClCell<'a>) (targetArray: ClArray<'a>) ->
624+
<@ fun (ndRange: Range1D) firstPosition count (value: ClCell<'a>) (targetArray: ClArray<'a>) ->
625625

626626
let gid = ndRange.GlobalID0
627627
let writePosition = gid + firstPosition
628628

629-
if writePosition < endPosition then
630-
629+
if gid < count then
631630
targetArray.[writePosition] <- value.Value @>
632631

633632
let kernel = clContext.Compile fill
634633

635634
fun (processor: MailboxProcessor<_>) value firstPosition count (targetArray: ClArray<'a>) ->
636-
if firstPosition + count > targetArray.Length then
637-
failwith ""
635+
if count = 0 then ()
636+
else
637+
if firstPosition + count > targetArray.Length then
638+
failwith ""
638639

639-
if firstPosition < 0 then failwith ""
640-
if count <= 0 then failwith "" // TODO()
640+
if firstPosition < 0 then failwith ""
641+
if count < 0 then failwith "" // TODO()
641642

642-
let ndRange =
643-
Range1D.CreateValid(count, workGroupSize)
643+
let ndRange =
644+
Range1D.CreateValid(count, workGroupSize)
644645

645-
let kernel = kernel.GetKernel()
646+
let kernel = kernel.GetKernel()
646647

647-
processor.Post(
648-
Msg.MsgSetArguments
649-
(fun () -> kernel.KernelFunc ndRange firstPosition (firstPosition + count) value targetArray)
650-
)
648+
processor.Post(
649+
Msg.MsgSetArguments
650+
(fun () -> kernel.KernelFunc ndRange firstPosition count value targetArray)
651+
)
651652

652-
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
653+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
653654

654655
let pairwise (clContext: ClContext) workGroupSize =
655656

@@ -659,18 +660,26 @@ module ClArray =
659660
let incGather =
660661
Gather.runInit Map.inc clContext workGroupSize
661662

663+
let map = map2 clContext workGroupSize <@ fun first second -> (first, second) @>
664+
662665
fun (processor: MailboxProcessor<_>) allocationMode (values: ClArray<'a>) ->
666+
if values.Length > 1 then
667+
let resultLength = values.Length - 1
663668

664-
let resultLength = values.Length - 1
669+
let firstItems =
670+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
665671

666-
let firstItems =
667-
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
672+
idGather processor values firstItems
668673

669-
idGather processor values firstItems
674+
let secondItems =
675+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
670676

671-
let secondItems =
672-
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
677+
incGather processor values secondItems
678+
679+
let result = map processor allocationMode firstItems secondItems
673680

674-
incGather processor values secondItems
681+
firstItems.Free processor
682+
secondItems.Free processor
675683

676-
firstItems, secondItems
684+
Some result
685+
else None

src/GraphBLAS-sharp.Backend/Matrix/SpGeMM/Expand.fs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@ module Expand =
2626

2727
fun (processor: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'b>) ->
2828

29-
let firstPointers, secondPointers =
30-
pairwise processor DeviceOnly matrix.RowPointers
31-
32-
let rowsLength = subtract processor DeviceOnly secondPointers firstPointers
33-
34-
firstPointers.Free processor
35-
secondPointers.Free processor
36-
37-
rowsLength
29+
// let firstPointers, secondPointers =
30+
// pairwise processor DeviceOnly matrix.RowPointers
31+
32+
// let rowsLength = subtract processor DeviceOnly secondPointers firstPointers
33+
//
34+
// firstPointers.Free processor
35+
// secondPointers.Free processor
36+
//
37+
// rowsLength
38+
clContext.CreateClArray [| |]
3839

3940
let getSegmentPointers (clContext: ClContext) workGroupSize =
4041

tests/GraphBLAS-sharp.Tests/Common/ClArray/Fill.fs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,28 @@ let processor = Context.defaultContext.Queue
1414

1515
let config =
1616
{ Utils.defaultConfig with
17-
arbitrary = [ typeof<Generators.ArrayAndChunkPositions> ] }
18-
19-
let makeTest<'a> isEqual testFun (value: 'a, targetIndex, count, target: 'a [] ) =
17+
arbitrary = [ typeof<Generators.Fill> ] }
2018

19+
let makeTest<'a> isEqual testFun (value: 'a, targetPosition, count, target: 'a [] ) =
2120
if target.Length > 0 then
2221

2322
let clTarget = context.CreateClArray target
2423
let clValue = context.CreateClCell value
2524

26-
testFun processor clValue 0 0 clTarget
25+
testFun processor clValue targetPosition count clTarget
2726

2827
// release
2928
let actual = clTarget.ToHostAndFree processor
3029

3130
// write to target
32-
Array.fill target targetIndex count value
31+
Array.fill target targetPosition count value
3332

3433
"Results must be the same"
3534
|> Utils.compareArrays isEqual actual target
3635

3736
let createTest<'a> isEqual =
3837
ClArray.fill context Utils.defaultWorkGroupSize
39-
|> makeTest isEqual
38+
|> makeTest<'a> isEqual
4039
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
4140

4241
let tests =
@@ -47,3 +46,4 @@ let tests =
4746

4847
createTest<float32> (=)
4948
createTest<bool> (=) ]
49+
|> testList "Fill"
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GraphBLAS.FSharp.Tests.Common.Backend.ClArray.Pairwise
1+
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.Pairwise
22

33
open Expecto
44
open Brahma.FSharp
@@ -14,32 +14,31 @@ let processor = Context.defaultContext.Queue
1414

1515
let config =
1616
{ Utils.defaultConfig with
17-
arbitrary = [ typeof<Generators.ArrayAndChunkPositions> ] }
17+
arbitrary = [ typeof<Generators.BufferCompatibleArray> ] }
1818

1919
let makeTest<'a> isEqual testFun (array: 'a [] ) =
2020
if array.Length > 0 then
2121

2222
let clArray = context.CreateClArray array
2323

24-
let (clFirstActual: ClArray<_>), (clSecondActual: ClArray<_>)
25-
= testFun processor HostInterop clArray
24+
testFun processor HostInterop clArray
25+
|> Option.bind (fun (clFirstActual: ClArray<_>, clSecondActual: ClArray<_>) ->
26+
let firstActual = clFirstActual.ToHostAndFree processor
27+
let secondActual = clSecondActual.ToHostAndFree processor
2628

27-
let firstActual = clFirstActual.ToHostAndFree processor
28-
let secondActual = clSecondActual.ToHostAndFree processor
29+
let firstExpected, secondExpected = Array.pairwise array |> Array.unzip
2930

30-
let firstExpected, secondExpected =
31-
Array.pairwise array
32-
|> Array.unzip
31+
"First results must be the same"
32+
|> Utils.compareArrays isEqual firstActual firstExpected
3333

34-
"First results must be the same"
35-
|> Utils.compareArrays isEqual firstActual firstExpected
36-
37-
"Second results must be the same"
38-
|> Utils.compareArrays isEqual secondActual secondExpected
34+
"Second results must be the same"
35+
|> Utils.compareArrays isEqual secondActual secondExpected
36+
None)
37+
|> ignore
3938

4039
let createTest<'a> isEqual =
4140
ClArray.pairwise context Utils.defaultWorkGroupSize
42-
|> makeTest isEqual
41+
|> makeTest<'a> isEqual
4342
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
4443

4544
let tests =
@@ -50,3 +49,4 @@ let tests =
5049

5150
createTest<float32> (=)
5251
createTest<bool> (=) ]
52+
|> testList "Pairwise"

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ module Generators =
960960
type AssignArray() =
961961
static let pairOfVectorsOfEqualSize (valuesGenerator: Gen<'a>) =
962962
gen {
963-
let! targetArrayLength = Gen.sized <| fun size -> Gen.choose (2, size)
963+
let! targetArrayLength = Gen.sized <| fun size -> Gen.choose (2, size + 2)
964964

965965
let! targetArray = Gen.arrayOfLength targetArrayLength valuesGenerator
966966

@@ -1016,3 +1016,63 @@ module Generators =
10161016
static member BoolType() =
10171017
pairOfVectorsOfEqualSize <| Arb.generate<bool>
10181018
|> Arb.fromGen
1019+
1020+
type Fill() =
1021+
static let pairOfVectorsOfEqualSize (valuesGenerator: Gen<'a>) =
1022+
gen {
1023+
let! value = valuesGenerator
1024+
1025+
let! targetArrayLength = Gen.sized <| fun size -> Gen.choose(1, size + 1)
1026+
1027+
let! targetArray = Gen.arrayOfLength targetArrayLength valuesGenerator
1028+
1029+
let! targetPosition = Gen.choose (0, targetArrayLength)
1030+
1031+
let! targetCount = Gen.choose(0, targetArrayLength - targetPosition)
1032+
1033+
return (value, targetPosition, targetCount, targetArray)
1034+
}
1035+
1036+
static member IntType() =
1037+
pairOfVectorsOfEqualSize <| Arb.generate<int>
1038+
|> Arb.fromGen
1039+
1040+
static member FloatType() =
1041+
pairOfVectorsOfEqualSize
1042+
<| (Arb.Default.NormalFloat()
1043+
|> Arb.toGen
1044+
|> Gen.map float)
1045+
|> Arb.fromGen
1046+
1047+
static member Float32Type() =
1048+
pairOfVectorsOfEqualSize
1049+
<| (normalFloat32Generator <| System.Random())
1050+
|> Arb.fromGen
1051+
1052+
static member SByteType() =
1053+
pairOfVectorsOfEqualSize <| Arb.generate<sbyte>
1054+
|> Arb.fromGen
1055+
1056+
static member ByteType() =
1057+
pairOfVectorsOfEqualSize <| Arb.generate<byte>
1058+
|> Arb.fromGen
1059+
1060+
static member Int16Type() =
1061+
pairOfVectorsOfEqualSize <| Arb.generate<int16>
1062+
|> Arb.fromGen
1063+
1064+
static member UInt16Type() =
1065+
pairOfVectorsOfEqualSize <| Arb.generate<uint16>
1066+
|> Arb.fromGen
1067+
1068+
static member Int32Type() =
1069+
pairOfVectorsOfEqualSize <| Arb.generate<int32>
1070+
|> Arb.fromGen
1071+
1072+
static member UInt32Type() =
1073+
pairOfVectorsOfEqualSize <| Arb.generate<uint32>
1074+
|> Arb.fromGen
1075+
1076+
static member BoolType() =
1077+
pairOfVectorsOfEqualSize <| Arb.generate<bool>
1078+
|> Arb.fromGen

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,5 @@ open GraphBLAS.FSharp.Tests
9595

9696
[<EntryPoint>]
9797
let main argv =
98-
testList "lol" [ Common.ClArray.Concat.tests ] |> testSequenced
98+
testList "lol" [ Common.ClArray.Pairwise.tests ] |> testSequenced
9999
|> runTestsWithCLIArgs [] argv

0 commit comments

Comments
 (0)