@@ -2901,9 +2901,8 @@ def test_softmax_f32_f32(self) -> None:
                 torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4 (3*4 x 4)
                 0.1,  # w_h_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
-                0.1,  # b_i_scale
+                0.1,  # b_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
-                0.1,  # b_h_scale
             ),
             (
                 "invalid_batch_size_2",
@@ -2918,9 +2917,8 @@ def test_softmax_f32_f32(self) -> None:
                 torch.ones((12, 4), dtype=torch.int8),  # weights_hidden: 12x4
                 0.1,  # w_h_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
-                0.1,  # b_i_scale
+                0.1,  # b_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
-                0.1,  # b_h_scale
             ),
             (
                 "non_zero_biases",
@@ -2933,11 +2931,10 @@ def test_softmax_f32_f32(self) -> None:
                 torch.tensor(
                     [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
                 ),  # bias_inputs: 12
-                0.1,  # b_i_scale
+                0.1,  # b_scale
                 torch.tensor(
                     [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int8
                 ),  # bias_hidden: 12
-                0.1,  # b_h_scale
             ),
             (
                 "negative_weights",
@@ -2954,9 +2951,8 @@ def test_softmax_f32_f32(self) -> None:
                 ),  # weights_hidden: 12x4 (alternating pattern)
                 0.1,  # w_h_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_inputs: 12
-                0.1,  # b_i_scale
+                0.1,  # b_scale
                 torch.zeros(12, dtype=torch.int8),  # bias_hidden: 12
-                0.1,  # b_h_scale
             ),
             (
                 "hidden_dim_8",
@@ -2969,9 +2965,8 @@ def test_softmax_f32_f32(self) -> None:
                 torch.ones((24, 8), dtype=torch.int8),  # weights_hidden: 24x8 (3*8 x 8)
                 0.1,  # w_h_scale
                 torch.zeros(24, dtype=torch.int8),  # bias_inputs: 24
-                0.1,  # b_i_scale
+                0.1,  # b_scale
                 torch.zeros(24, dtype=torch.int8),  # bias_hidden: 24
-                0.1,  # b_h_scale
             ),
         ]
     )
@@ -2985,9 +2980,8 @@ def test_quantized_w8a32_gru(
         weights_hidden: torch.Tensor,
         w_h_scale: float,
         bias_inputs: torch.Tensor,
-        b_i_scale: float,
+        b_scale: float,
         bias_hidden: torch.Tensor,
-        b_h_scale: float,
     ) -> None:
 
         if name == "invalid_batch_size_2":
@@ -3000,9 +2994,8 @@ def test_quantized_w8a32_gru(
                     weights_hidden,
                     w_h_scale,
                     bias_inputs,
-                    b_i_scale,
+                    b_scale,
                     bias_hidden,
-                    b_h_scale,
                 )
             self.assertIn(
                 "Leading dimension 0 of hidden state must be 1", str(context.exception)
@@ -3017,9 +3010,8 @@ def test_quantized_w8a32_gru(
             weights_hidden,
             w_h_scale,
             bias_inputs,
-            b_i_scale,
+            b_scale,
             bias_hidden,
-            b_h_scale,
         )
 
         # Verify output properties
@@ -3028,10 +3020,11 @@ def test_quantized_w8a32_gru(
             torch.float32,
             f"Output dtype should be float32 in {name}",
         )
+        expected_shape = (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])
         self.assertEqual(
             output.shape,
-            (2, *hidden.shape),
-            f"Output shape should match {(2, *hidden.shape)} in {name}",
+            expected_shape,
+            f"Output shape should match {expected_shape} in {name}",
         )
         assert isinstance(output, torch.Tensor)
 
@@ -3064,7 +3057,6 @@ def test_quantized_w8a32_gru_invalid_hidden_dim(self) -> None:
                 bias_inputs,
                 0.1,
                 bias_hidden,
-                0.1,
             )
 
         self.assertIn(
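Note on the assertion change above: the expected output shape is now derived from both the inputs and the hidden state, (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1]), instead of (2, *hidden.shape). The snippet below is a minimal, standalone sketch of that shape calculation only; the helper name and the example tensor sizes are hypothetical and not part of the patch, and the actual quantized GRU operator call is elided from this diff.

import torch

def expected_gru_output_shape(inputs: torch.Tensor, hidden: torch.Tensor) -> tuple:
    # Hypothetical helper mirroring the updated assertion: a leading 2 (two
    # stacked result tensors), the inputs' first two dimensions, and the
    # hidden state's last dimension.
    return (2, inputs.shape[0], inputs.shape[1], hidden.shape[-1])

# Example sizes are assumptions for illustration only.
inputs = torch.zeros((1, 3, 4), dtype=torch.float32)
hidden = torch.zeros((1, 4), dtype=torch.float32)
print(expected_gru_output_shape(inputs, hidden))  # (2, 1, 3, 4)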