@@ -88,12 +88,14 @@ def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor:
8888
8989 # To tensor [0..1]
9090 import numpy as np
91+ mean = np .array (IMAGENET_MEAN , dtype = np .float32 )
92+ std = np .array (IMAGENET_STD , dtype = np .float32 )
9193 arr = np .asarray (img ).astype ("float32" ) / 255.0 # H,W,3
9294 # Normalize
93- arr = (arr - IMAGENET_MEAN ) / IMAGENET_STD
95+ arr = (arr - mean ) / std
9496 # HWC -> CHW
95- arr = np .transpose (arr , (2 , 0 , 1 ))
96- return torch .from_numpy (arr ) # float32, shape (3,224,224)
97+ arr = np .transpose (arr , (2 , 0 , 1 )). astype ( "float32" )
98+ return torch .from_numpy (arr ). float () # float32, shape (3,224,224)
9799
98100
99101class RateLimitDoFn (beam .DoFn ):
@@ -174,9 +176,28 @@ def __init__(self, top_k: int, model_name: str):
174176
175177 def process (self , kv : Tuple [str , PredictionResult ]):
176178 image_id , pred = kv
177- logits = pred .inference [
178- "logits" ] # torch.Tensor [B, num_classes] or [num_classes]
179- if isinstance (logits , torch .Tensor ) and logits .ndim == 1 :
179+
180+ # pred can be PredictionResult OR raw inference object.
181+ inference_obj = pred .inference if hasattr (pred , "inference" ) else pred
182+
183+ # inference_obj can be dict {'logits': tensor} OR tensor directly.
184+ if isinstance (inference_obj , dict ):
185+ logits = inference_obj .get ("logits" , None )
186+ if logits is None :
187+ # fallback: no usable 'logits' entry, so try the dict's first value
188+ try :
189+ logits = next (iter (inference_obj .values ()))
190+ except Exception :
191+ logits = None
192+ else :
193+ logits = inference_obj
194+
195+ if not isinstance (logits , torch .Tensor ):
196+ logging .warning ("Unexpected logits type for %s: %s" , image_id , type (logits ))
197+ return
198+
199+ # Promote unbatched 1-D logits [C] to [1, C]; batched input is left as-is
200+ if logits .ndim == 1 :
180201 logits = logits .unsqueeze (0 )
181202
182203 probs = F .softmax (logits , dim = - 1 ) # [B, C]
@@ -480,7 +501,7 @@ def run(
480501
481502 to_infer = (
482503 preprocessed
483- | 'ToKeyedTensor' >> beam .Map (lambda kv : (kv [0 ], kv [1 ]["tensor" ])))
504+ | 'ToKeyedTensor' >> beam .Map (lambda kv : (kv [0 ], kv [1 ]["tensor" ]. float () )))
484505
485506 predictions = (
486507 to_infer
0 commit comments