Skip to content

Commit 126d30b

Browse files
committed
Fix float for imagenet rightfit
1 parent 9dca475 commit 126d30b

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor:
8888

8989
# To tensor [0..1]
9090
import numpy as np
91+
mean = np.array(IMAGENET_MEAN, dtype=np.float32)
92+
std = np.array(IMAGENET_STD, dtype=np.float32)
9193
arr = np.asarray(img).astype("float32") / 255.0 # H,W,3
9294
# Normalize
93-
arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
95+
arr = (arr - mean) / std
9496
# HWC -> CHW
95-
arr = np.transpose(arr, (2, 0, 1))
96-
return torch.from_numpy(arr) # float32, shape (3,224,224)
97+
arr = np.transpose(arr, (2, 0, 1)).astype("float32")
98+
return torch.from_numpy(arr).float() # float32, shape (3,224,224)
9799

98100

99101
class RateLimitDoFn(beam.DoFn):
@@ -174,9 +176,28 @@ def __init__(self, top_k: int, model_name: str):
174176

175177
def process(self, kv: Tuple[str, PredictionResult]):
176178
image_id, pred = kv
177-
logits = pred.inference[
178-
"logits"] # torch.Tensor [B, num_classes] or [num_classes]
179-
if isinstance(logits, torch.Tensor) and logits.ndim == 1:
179+
180+
# pred can be PredictionResult OR raw inference object.
181+
inference_obj = pred.inference if hasattr(pred, "inference") else pred
182+
183+
# inference_obj can be dict {'logits': tensor} OR tensor directly.
184+
if isinstance(inference_obj, dict):
185+
logits = inference_obj.get("logits", None)
186+
if logits is None:
187+
# fallback: try first value if dict shape differs
188+
try:
189+
logits = next(iter(inference_obj.values()))
190+
except Exception:
191+
logits = None
192+
else:
193+
logits = inference_obj
194+
195+
if not isinstance(logits, torch.Tensor):
196+
logging.warning("Unexpected logits type for %s: %s", image_id, type(logits))
197+
return
198+
199+
# Ensure shape [1, C]
200+
if logits.ndim == 1:
180201
logits = logits.unsqueeze(0)
181202

182203
probs = F.softmax(logits, dim=-1) # [B, C]
@@ -480,7 +501,7 @@ def run(
480501

481502
to_infer = (
482503
preprocessed
483-
| 'ToKeyedTensor' >> beam.Map(lambda kv: (kv[0], kv[1]["tensor"])))
504+
| 'ToKeyedTensor' >> beam.Map(lambda kv: (kv[0], kv[1]["tensor"].float())))
484505

485506
predictions = (
486507
to_infer

0 commit comments

Comments
 (0)