Skip to content

Commit

Permalink
[Improvement] make description optional (#306)
Browse files Browse the repository at this point in the history
* make description optional

* disable tool-response validation against the `returns` schema

* disable the log-detection test for response validation
  • Loading branch information
wj-Mcat committed Jan 24, 2024
1 parent c0b04d8 commit 87e4afb
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 39 deletions.
24 changes: 13 additions & 11 deletions erniebot-agent/src/erniebot_agent/tools/remote_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,19 @@ async def __post_process__(self, tool_response: dict) -> dict:
"请务必确保每个符合'file-'格式的字段只出现一次,无需将其转换为链接,也无需添加任何HTML、Markdown或其他格式化元素。"
)

if self.tool_view.returns is not None:
try:
origin_tool_response = deepcopy(tool_response)
valid_tool_response = self.tool_view.returns(**origin_tool_response).model_dump(mode="json")
tool_response.update(valid_tool_response)
except Exception as e:
_logger.warning(
"Unable to validate the 'tool_response' against the schema defined in the YAML file. "
f"The specific error encountered is: '<{e}>'. "
"As a result, the original response from the tool will be used.",
)
# if self.tool_view.returns is not None:
# try:
# origin_tool_response = deepcopy(tool_response)
# valid_tool_response = self.tool_view.returns(
# **origin_tool_response
# ).model_dump(mode="json")
# tool_response.update(valid_tool_response)
# except Exception as e:
# _logger.warning(
# "Unable to validate the 'tool_response' against the schema defined in the YAML file. "
# f"The specific error encountered is: '<{e}>'. "
# "As a result, the original response from the tool will be used.",
# )
return tool_response

async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
Expand Down
3 changes: 0 additions & 3 deletions erniebot-agent/src/erniebot_agent/tools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,6 @@ def from_openapi_dict(cls, schema: dict) -> Type[ToolParameterView]:
if "type" not in field_dict:
raise ToolError(f"`type` field not found in `{field_name}` property", stage="Loading")

if "description" not in field_dict:
raise ToolError(f"`description` field not found in `{field_name}` property", stage="Loading")

if field_name.startswith("__"):
continue

Expand Down
50 changes: 25 additions & 25 deletions erniebot-agent/tests/unit_tests/tools/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,31 +544,31 @@ async def test_enum_v1(self):
self.assertEqual(result["enum_field"], "2")
self.assertEqual(result["no_enum_field"], "no_enum_value")

# Pre-commit version of this test (this commit replaces it with the commented-out
# copy that follows). It checks that, when a tool response fails validation against
# the tool's `returns` schema, RemoteTool logs a warning and falls back to the raw
# response. This commit comments out that validation in
# remote_tool.__post_process__. That validation is the only visible source of the
# warning text asserted below, so the test would fail as written.
# NOTE(review): prefer @unittest.skip("<reason>") over commenting the test out, so it
# stays visible in test reports and is easy to re-enable later.
@responses.activate
async def test_enum_v1_with_wrong_dtype(self):
# Fetch the `enum_v1` tool from the toolkit fixture (created outside this view).
tool = self.toolkit.get_tool("enum_v1")

# Mock the endpoint so that `enum_field` comes back as the int 2. test_enum_v1
# expects the string "2" for this field, so presumably the YAML `returns` schema
# types it as a string enum and an int fails validation (TODO: confirm against the YAML).
# NOTE(review): `responses` intercepts calls made through the `requests` library;
# this assumes the tool's HTTP request goes through `requests` (confirm).
responses.post(
"http://example.com/enum_v1_dtype",
json={"enum_field": 2, "no_enum_field": "no_enum_value"},
)

# Point the tool at the mocked path. The path is presumably resolved against
# http://example.com by the toolkit (confirm). This mutates the tool's view in place.
# NOTE(review): if self.toolkit is shared across tests, reset the path afterwards
# or rebuild the toolkit in setUp.
tool.tool_view.uri = "enum_v1_dtype"
# Capture INFO-and-above records from the remote_tool module logger. assertLogs
# itself fails if that logger emits nothing inside the block.
with self.assertLogs("erniebot_agent.tools.remote_tool", level="INFO") as cm:
result = await tool()

# Keep only the validation-failure warnings.
logs = [item for item in cm.output if "Unable to validate the 'tool_response'" in item]

# test raise warning log msg
# Exactly one warning is expected. Its prefix matches the _logger.warning text in
# __post_process__. The "1 validation error for" part presumably comes from the
# pydantic ValidationError message.
self.assertEqual(len(logs), 1)
warning_log_msg = (
"Unable to validate the 'tool_response' against the schema defined "
"in the YAML file. The specific error encountered is: '<1 validation error for "
)
self.assertIn(warning_log_msg, logs[0])

# Because validation failed, the raw response is returned unchanged:
# enum_field stays the int 2 rather than being coerced.
self.assertEqual(result["enum_field"], 2)
self.assertEqual(result["no_enum_field"], "no_enum_value")
# @responses.activate
# async def test_enum_v1_with_wrong_dtype(self):
# tool = self.toolkit.get_tool("enum_v1")

# responses.post(
# "http://example.com/enum_v1_dtype",
# json={"enum_field": 2, "no_enum_field": "no_enum_value"},
# )

# tool.tool_view.uri = "enum_v1_dtype"
# with self.assertLogs("erniebot_agent.tools.remote_tool", level="INFO") as cm:
# result = await tool()

# logs = [item for item in cm.output if "Unable to validate the 'tool_response'" in item]

# # test raise warning log msg
# self.assertEqual(len(logs), 1)
# warning_log_msg = (
# "Unable to validate the 'tool_response' against the schema defined "
# "in the YAML file. The specific error encountered is: '<1 validation error for "
# )
# self.assertIn(warning_log_msg, logs[0])

# self.assertEqual(result["enum_field"], 2)
# self.assertEqual(result["no_enum_field"], "no_enum_value")

@responses.activate
async def test_enum_v2(self):
Expand Down

0 comments on commit 87e4afb

Please sign in to comment.