diff --git a/src/scriptworker/context.py b/src/scriptworker/context.py index f60059e5..a57eb0f3 100644 --- a/src/scriptworker/context.py +++ b/src/scriptworker/context.py @@ -59,6 +59,7 @@ class Context(object): config: Optional[Dict[str, Any]] = None credentials_timestamp: Optional[int] = None + credentials_fd: int = -1 proc: Optional[task_process.TaskProcess] = None queue: Optional[Queue] = None session: Optional[aiohttp.ClientSession] = None @@ -98,6 +99,8 @@ def claim_task(self, claim_task: Optional[Dict[str, Any]]) -> None: if claim_task: self.task = claim_task["task"] self.verify_task() + # flags=0 to let the child inherit this fd + self.credentials_fd = os.memfd_create("scriptworker_temp_creds", flags=0) self.temp_credentials = claim_task["credentials"] path = os.path.join(self.config["work_dir"], "task.json") assert self.task @@ -105,6 +108,8 @@ def claim_task(self, claim_task: Optional[Dict[str, Any]]) -> None: else: self.temp_credentials = None self.task = None + os.close(self.credentials_fd) + self.credentials_fd = -1 def verify_task(self) -> None: """Run some task sanity checks on ``self.task``.""" @@ -193,6 +198,10 @@ def temp_credentials(self) -> Optional[Dict[str, Any]]: def temp_credentials(self, credentials: Optional[Dict[str, Any]]) -> None: self._temp_credentials = credentials self.temp_queue = self.create_queue(self.temp_credentials) + if credentials: + data = json.dumps(credentials, indent=2, sort_keys=True).encode("ascii") + # use pwrite so we don't confuse the child by changing the file offset + assert os.pwrite(self.credentials_fd, data, 0) == len(data) def write_json(self, path: str, contents: Dict[str, Any], message: str) -> None: """Write json to disk. diff --git a/src/scriptworker/task.py b/src/scriptworker/task.py index 4f4751ab..ad0e8d69 100644 --- a/src/scriptworker/task.py +++ b/src/scriptworker/task.py @@ -672,7 +672,16 @@ async def run_task(context, to_cancellable_process): env["TASK_ID"] = context.task_id or "None" env["RUN_ID"] = str(get_run_id(context.claim_task)) env["TASKCLUSTER_ROOT_URL"] = context.config["taskcluster_root_url"] - kwargs = {"stdout": PIPE, "stderr": PIPE, "stdin": None, "close_fds": True, "preexec_fn": lambda: os.setsid(), "env": env} # pragma: no branch + env["TASKCLUSTER_CREDENTIALS_FD"] = str(context.credentials_fd) + kwargs = { + "stdout": PIPE, + "stderr": PIPE, + "stdin": None, + "close_fds": True, + "preexec_fn": lambda: os.setsid(), + "env": env, + "pass_fds": (context.credentials_fd,), + } # pragma: no branch timeout = get_task_maxruntime(context.task, context.config["task_max_timeout"]) subprocess = await asyncio.create_subprocess_exec(*context.config["task_script"], **kwargs) diff --git a/tests/test_context.py b/tests/test_context.py index ff534cf8..5549262c 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -77,6 +77,46 @@ async def test_set_reset_task(rw_context, claim_task, reclaim_task): assert rw_context.temp_queue is None +def test_credentials_fd_initial(rw_context): + assert rw_context.credentials_fd == -1 + + +@pytest.mark.asyncio +async def test_credentials_fd_opened_on_claim_task(rw_context, claim_task): + rw_context.claim_task = claim_task + assert rw_context.credentials_fd >= 0 + os.fstat(rw_context.credentials_fd) # raises OSError if fd is invalid + + +@pytest.mark.asyncio +async def test_credentials_fd_content(rw_context, claim_task): + rw_context.claim_task = claim_task + fd = rw_context.credentials_fd + size = os.fstat(fd).st_size + data = os.pread(fd, size, 0) + assert json.loads(data) == claim_task["credentials"] + + +@pytest.mark.asyncio +async def test_credentials_fd_updated_on_reclaim(rw_context, claim_task, reclaim_task): + rw_context.claim_task = claim_task + rw_context.reclaim_task = reclaim_task + fd = rw_context.credentials_fd + size = os.fstat(fd).st_size + data = os.pread(fd, size, 0) + assert json.loads(data) == reclaim_task["credentials"] + + +@pytest.mark.asyncio +async def test_credentials_fd_closed_on_reset(rw_context, claim_task): + rw_context.claim_task = claim_task + fd = rw_context.credentials_fd + rw_context.claim_task = None + assert rw_context.credentials_fd == -1 + with pytest.raises(OSError): + os.fstat(fd) + + @pytest.mark.asyncio async def test_projects(rw_context, mocker): fake_projects = {"mozilla-central": "blah", "count": 0} diff --git a/tests/test_task.py b/tests/test_task.py index 1b38e324..121a521c 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -567,6 +567,24 @@ async def test_run_task_timeout(context): assert context.proc is None +@pytest.mark.asyncio +async def test_run_task_credentials_fd(context): + """The subprocess receives the temp credentials via the fd in TASKCLUSTER_CREDENTIALS_FD.""" + context.config["task_script"] = ( + sys.executable, + "-c", + "import json, os, sys; " + "fd = int(os.environ['TASKCLUSTER_CREDENTIALS_FD']); " + "size = os.fstat(fd).st_size; " + "sys.stdout.write(os.pread(fd, size, 0).decode('ascii'))", + ) + await swtask.run_task(context, noop_to_cancellable_process) + log_file = log.get_log_filename(context) + contents = read(log_file) + parsed, _ = json.JSONDecoder().raw_decode(contents) + assert parsed == context.temp_credentials + + # report* {{{1 @pytest.mark.asyncio async def test_reportCompleted(context, successful_queue):