Skip to content

Commit 212c0aa

Browse files
Consistent Trace Nesting in Parallel Function Calling (#162)
Refactors intermediate step manager, otel adaptor, and langchain callback handler to work in async scenarios. Crucially, the biggest changes came in otel adaptor to make traces nest appropriately in asyn function call scenarios (many calls in parallel). The fixes to the langchain callback handler allows for the callback intermediate step to nest the function call underneath it. The intermediate step manager needed to be refactored to prevent cross-context resets of the context variable. Ostensibly, Functionality changed from old trace like so: <img width="475" alt="old trace" src="https://github.com/user-attachments/assets/1e577ee0-dbfc-4dae-80fb-8c03a57e75fb" /> to a new trace like this: <img width="475" alt="new trace" src="https://github.com/user-attachments/assets/aeced7ae-9876-420b-8632-dd78a5fd4a1f" /> ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/AgentIQ/blob/develop/docs/source/advanced/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. Authors: - Dhruv Nandakumar (https://github.com/dnandakumar-nv) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: #162
1 parent 3cc9c7b commit 212c0aa

File tree

4 files changed

+161
-11
lines changed

4 files changed

+161
-11
lines changed

‎src/aiq/builder/intermediate_step_manager.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,15 @@ def push_intermediate_step(self, payload: IntermediateStepPayload) -> None:
8181
logger.warning("Step id %s not found in outstanding start steps", payload.UUID)
8282
return
8383

84-
# If we are in the same coroutine, we should have the same parent step id. If so, unset the current step id.
85-
if (parent_step_id == payload.UUID):
86-
_current_open_step_id.reset(open_step.token)
87-
84+
# Restore the parent step ID directly instead of using a cross‑context token.
85+
if parent_step_id == payload.UUID:
86+
_current_open_step_id.set(open_step.step_parent_id)
8887
else:
89-
# Manually set the parent step ID. This happens when running on the thread pool
88+
# Different context (e.g. thread‑pool); safely restore the parent ID **without**
89+
# trying to use a token that belongs to another Context.
90+
_current_open_step_id.set(open_step.step_parent_id)
9091
parent_step_id = open_step.step_parent_id
92+
9193
elif (payload.event_state == IntermediateStepState.CHUNK):
9294

9395
# Get the current step from the outstanding steps

‎src/aiq/observability/async_otel_listener.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(self, context_state: AIQContextState | None = None):
7070
self._outstanding_spans: dict[str, Span] = {}
7171

7272
# Stack of spans, for when we need to create a child span
73-
self._span_stack: list[Span] = []
73+
self._span_stack: dict[str, Span] = {}
7474

7575
self._running = False
7676

@@ -152,7 +152,7 @@ async def _cleanup(self):
152152

153153
self._outstanding_spans.clear()
154154

155-
if self._span_stack:
155+
if len(self._span_stack) > 0:
156156
logger.error(
157157
"Not all spans were closed. Ensure all start events have a corresponding end event. Remaining: %s",
158158
self._span_stack)
@@ -175,7 +175,10 @@ def _process_start_event(self, step: IntermediateStep):
175175
parent_ctx = None
176176

177177
if (len(self._span_stack) > 0):
178-
parent_span = self._span_stack[-1]
178+
parent_span = self._span_stack.get(step.function_ancestry.parent_id, None)
179+
if parent_span is None:
180+
logger.warning("No parent span found for step %s", step.UUID)
181+
return
179182

180183
parent_ctx = set_span_in_context(parent_span)
181184

@@ -230,7 +233,7 @@ def _process_start_event(self, step: IntermediateStep):
230233
sub_span.set_attribute(SpanAttributes.INPUT_VALUE, serialized_input)
231234
sub_span.set_attribute(SpanAttributes.INPUT_MIME_TYPE, "application/json" if is_json else "text/plain")
232235

233-
self._span_stack.append(sub_span)
236+
self._span_stack[step.UUID] = sub_span
234237

235238
self._outstanding_spans[step.UUID] = sub_span
236239

@@ -243,7 +246,7 @@ def _process_end_event(self, step: IntermediateStep):
243246
logger.warning("No subspan found for step %s", step.UUID)
244247
return
245248

246-
self._span_stack.pop()
249+
self._span_stack.pop(step.UUID, None)
247250

248251
# Optionally add more attributes from usage_info or data
249252
usage_info = step.payload.usage_info

‎src/aiq/profiler/callbacks/langchain_callback_handler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback): # p
5050
completion_tokens: int = 0
5151
successful_requests: int = 0
5252
raise_error = True # Override to raise error and run inline
53-
run_inline = False
53+
run_inline = True
5454

5555
def __init__(self) -> None:
5656
super().__init__()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import contextvars
17+
import threading
18+
import uuid
19+
20+
import pytest
21+
22+
from aiq.builder.intermediate_step_manager import IntermediateStepManager
23+
from aiq.builder.intermediate_step_manager import IntermediateStepPayload
24+
from aiq.builder.intermediate_step_manager import IntermediateStepState
25+
from aiq.builder.intermediate_step_manager import _current_open_step_id
26+
27+
# --------------------------------------------------------------------------- #
28+
# Minimal stubs so the tests do not need the whole aiq code-base
29+
# --------------------------------------------------------------------------- #
30+
31+
32+
class _DummyStream(list):
33+
"""Bare-bones Observable / Subject replacement."""
34+
35+
def on_next(self, value): # reactive push
36+
self.append(value)
37+
38+
# simple subscribe: just call back synchronously
39+
def subscribe(self, on_next, on_error=None, on_complete=None):
40+
for v in self:
41+
on_next(v)
42+
return lambda: None # fake Subscription
43+
44+
45+
class _DummyFunction: # what active_function.get() returns
46+
47+
def __init__(self, name="fn", fid=None, parent_name=None):
48+
self.function_name = name
49+
self.function_id = fid or str(uuid.uuid4())
50+
self.parent_name = parent_name
51+
52+
53+
class _DummyContextState:
54+
"""Only what IntermediateStepManager touches."""
55+
56+
def __init__(self):
57+
self.active_function = contextvars.ContextVar("active_function")
58+
self.active_function.set(_DummyFunction())
59+
60+
self.event_stream = contextvars.ContextVar("event_stream")
61+
self.event_stream.set(_DummyStream())
62+
63+
64+
# --------------------------------------------------------------------------- #
65+
# Fixtures
66+
# --------------------------------------------------------------------------- #
67+
68+
69+
@pytest.fixture()
70+
def mgr():
71+
"""Fresh manager + its stubbed context-state for each test."""
72+
return IntermediateStepManager(context_state=_DummyContextState())
73+
74+
75+
def _payload(step_id=None, state=IntermediateStepState.START, name="step", etype="LLM_START"):
76+
"""Helper to create a payload with only the fields the manager uses."""
77+
return IntermediateStepPayload(
78+
UUID=step_id or str(uuid.uuid4()),
79+
name=name,
80+
event_state=state,
81+
event_type=etype,
82+
)
83+
84+
85+
# --------------------------------------------------------------------------- #
86+
# Tests
87+
# --------------------------------------------------------------------------- #
88+
89+
90+
def test_start_pushes_event_and_tracks_open_step(mgr):
91+
pay = _payload()
92+
mgr.push_intermediate_step(pay)
93+
94+
# one event captured
95+
stream = mgr._context_state.event_stream.get()
96+
assert len(stream) == 1
97+
# step now in outstanding dict
98+
assert pay.UUID in mgr._outstanding_start_steps
99+
100+
101+
def test_chunk_preserves_parent_id(mgr):
102+
pay = _payload()
103+
mgr.push_intermediate_step(pay) # START
104+
parent_id = _current_open_step_id.get()
105+
106+
chunk = _payload(step_id=pay.UUID, state=IntermediateStepState.CHUNK)
107+
mgr.push_intermediate_step(chunk)
108+
109+
# parent should still be the START id
110+
assert _current_open_step_id.get() == parent_id
111+
112+
113+
def test_end_same_context_restores_parent(mgr):
114+
pay = _payload()
115+
mgr.push_intermediate_step(pay)
116+
mgr.push_intermediate_step(_payload(step_id=pay.UUID, state=IntermediateStepState.END, etype="LLM_END"))
117+
118+
# open-step removed, ContextVar back to parent (None)
119+
assert pay.UUID not in mgr._outstanding_start_steps
120+
121+
122+
def _end_in_thread(manager, payload):
123+
"""Helper for cross-thread END."""
124+
manager.push_intermediate_step(payload)
125+
126+
127+
def test_end_other_thread_no_token_error(mgr):
128+
pay = _payload()
129+
mgr.push_intermediate_step(pay)
130+
131+
end_pay = _payload(step_id=pay.UUID, state=IntermediateStepState.END, etype="LLM_END")
132+
t = threading.Thread(target=_end_in_thread, args=(mgr, end_pay))
133+
t.start()
134+
t.join()
135+
136+
# still cleaned up
137+
assert pay.UUID not in mgr._outstanding_start_steps
138+
139+
140+
def test_mismatched_chunk_logs_warning(mgr, caplog):
141+
# CHUNK without START
142+
chunk = _payload(state=IntermediateStepState.CHUNK, etype="LLM_NEW_TOKEN")
143+
mgr.push_intermediate_step(chunk)
144+
145+
assert "no matching start step" in caplog.text.lower()

0 commit comments

Comments
 (0)