Skip to content

Commit

Permalink
[BugFixes] Fix file testing (#302)
Browse files Browse the repository at this point in the history
* update file-schema testing

* update file-testing

* fix byte output

* fix lint
  • Loading branch information
wj-Mcat committed Jan 24, 2024
1 parent 885addb commit c0b04d8
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 18 deletions.
36 changes: 21 additions & 15 deletions erniebot-agent/src/erniebot_agent/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,26 @@ async def create_file_from_data(
)


# NOTE(review): rendered diff hunk. It appears that the multi-line signature taking `mime_type`
# (returning bytes) and the early `return byte_str` are the removed pre-commit lines, and that
# the one-line signature returning str is the post-commit version. Confirm against the repo.
async def get_content_by_file_id(
file_id: str, format: str, mime_type: str, file_manager: FileManager
) -> bytes:
# Look up a file by id in `file_manager` and return its contents as a str.
# When format == "byte" the contents are base64-encoded first; otherwise the raw
# bytes are decoded with the default codec of bytes.decode() (UTF-8).
async def get_content_by_file_id(file_id: str, format: str, file_manager: FileManager) -> str:
# The id may arrive wrapped as "<file>...</file>"; strip the markers before lookup.
file_id = file_id.replace("<file>", "").replace("</file>", "")
file = file_manager.look_up_file_by_id(file_id)
byte_str = await file.read_contents()
return byte_str
# NOTE(review): is_file_config() also treats format "binary" as a file field. For those
# files the raw bytes reach .decode() below unencoded, which raises UnicodeDecodeError on
# non-UTF-8 content such as PNG data. Confirm the intended handling for "binary".
if format == "byte":
byte_str = base64.b64encode(byte_str)

return byte_str.decode()


def is_file_config(json_schema_extra: dict) -> bool:
    """Check whether a field schema describes a file field.

    A field is treated as a file when its ``format`` is ``"byte"`` or ``"binary"``.

    Args:
        json_schema_extra (dict): the schema/config of a single field, as loaded
            from the tool's yaml file. ``None`` or an empty dict (e.g. a missing
            ``array_items_schema``) is accepted and treated as "not a file".

    Returns:
        bool: whether the field is a file field.
    """
    # Callers pass the result of ``dict.get(..., None)`` straight through, so the
    # schema may be None; that must mean "not a file", not an AttributeError.
    if not json_schema_extra:
        return False
    return json_schema_extra.get("format") in ("byte", "binary")


@no_type_check
Expand All @@ -233,13 +246,12 @@ async def parse_json_request(
"Please check the format of yaml in current tool."
)

if model_field.annotation == str and "x-ebagent-file-mime-type" in model_field.json_schema_extra:
if model_field.annotation == str and is_file_config(model_field.json_schema_extra):
format = model_field.json_schema_extra.get("format", None)
mime_type = model_field.json_schema_extra.get("x-ebagent-file-mime-type", None)

if format is not None and mime_type is not None:
if format is not None:
file_content = await get_content_by_file_id(
json_dict[field_name], format=format, mime_type=mime_type, file_manager=file_manager
json_dict[field_name], format=format, file_manager=file_manager
)
result[field_name] = file_content
elif issubclass(model_field.annotation, ToolParameterView):
Expand All @@ -254,19 +266,13 @@ async def parse_json_request(
else:
array_json_schema = model_field.json_schema_extra.get("array_items_schema", None)
sub_class = get_args(model_field.annotation)[0]
if (
list_type == "string"
and array_json_schema is not None
and array_json_schema.get("x-ebagent-file-mime-type", None)
):
if list_type == "string" and is_file_config(array_json_schema):
format = array_json_schema["format"]
mime_type = array_json_schema["x-ebagent-file-mime-type"]
files = []
for file_id in json_dict[field_name]:
file_content = await get_content_by_file_id(
file_id,
format=format,
mime_type=mime_type,
file_manager=file_manager,
)
files.append(file_content)
Expand Down
63 changes: 61 additions & 2 deletions erniebot-agent/tests/fixtures/openapis/file.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,35 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/file_v6"
/file_v7:
post:
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/file_v7_input"
responses:
"200":
description: 列表展示完成
content:
application/json:
schema:
$ref: "#/components/schemas/file_v7_output"
/file_v8:
post:
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/file_v8"
responses:
"200":
description: 列表展示完成
content:
application/json:
schema: {}
components:
schemas:
# {file: ["base64", "base64-string"]}
Expand Down Expand Up @@ -143,7 +172,6 @@ components:
not_file_field:
type: string
description: not-file-content

# {first_file: "base64-string", second_file: "base64-string"}
file_v6:
type: object
Expand All @@ -157,4 +185,35 @@ components:
type: string
description: string
format: byte
x-ebagent-file-mime-type: image/png
x-ebagent-file-mime-type: image/png
# {first_file: "base64-string", second_file: "base64-string"}
file_v7_input:
type: object
properties:
first_file:
type: string
description: string
format: byte
x-ebagent-file-mime-type: image/png
file_v7_output:
type: object
properties:
second_file:
type: string
description: string
format: byte
x-ebagent-file-mime-type: image/png
file_v8:
type: object
required: [file]
properties:
file:
type: array
items:
type: string
format: byte
description: 单词本单词列表
not_file_field:
type: string
description: not_file_field
default: "222"
46 changes: 45 additions & 1 deletion erniebot-agent/tests/unit_tests/tools/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
is_optional_type,
json_type,
)
from erniebot_agent.tools.utils import tool_response_contains_file
from erniebot_agent.tools.utils import parse_json_request, tool_response_contains_file
from erniebot_agent.utils.common import create_enum_class


Expand Down Expand Up @@ -483,6 +483,50 @@ async def test_file_v6(self):
file_content = base64.b64decode(file_content_1)
self.assertEqual(file_content, file_content_from_file_manager)

@responses.activate
async def test_file_v7(self):
    """Round-trip a byte-format file through the ``file_v7`` tool.

    The input file's id is sent as ``first_file``. The mocked endpoint answers
    with a base64 ``second_file``, and the file the tool returns under that key
    must hold the decoded bytes.
    """
    tool = self.toolkit.get_tool("file_v7")
    manager = GlobalFileManagerHandler().get()

    raw_text = str(uuid4())
    encoded = base64.b64encode(raw_text.encode())
    input_file = await manager.create_file_from_bytes(encoded, filename="a.png")

    # The mocked service replies with the same base64 payload stored in the input file.
    responses.post("http://example.com/file_v7", json={"second_file": encoded.decode()})

    result = await tool(first_file=input_file.id)

    output_file: File = manager.look_up_file_by_id(result["second_file"])
    stored_bytes = await output_file.read_contents()
    self.assertEqual(base64.b64decode(encoded), stored_bytes)

@responses.activate
async def test_file_v8(self):
    """Check that an array of byte-format file ids is expanded to base64 strings.

    ``parse_json_request`` must replace each id in ``file`` with the base64
    encoding of that file's contents, in order. The tool call itself then hits
    a mocked endpoint that returns an empty JSON object.
    """
    tool = self.toolkit.get_tool("file_v8")
    file_manager = GlobalFileManagerHandler().get()

    file_ids, file_contents = [], []
    # NOTE(review): only one file is created; more would also exercise ordering.
    for _ in range(1):
        file_content = str(uuid4()).encode()
        file_contents.append(base64.b64encode(file_content).decode())

        file = await file_manager.create_file_from_bytes(file_content, filename="a.png")
        file_ids.append(file.id)

    responses.post("http://example.com/file_v8", json={})
    self.assertIsNotNone(tool.tool_view.parameters)
    tool_arguments = await parse_json_request(
        tool.tool_view.parameters, {"file": file_ids}, file_manager=file_manager
    )
    # Check the length first: a per-item loop alone passes vacuously if entries are dropped.
    self.assertEqual(len(tool_arguments["file"]), len(file_ids))
    for expected, actual in zip(file_contents, tool_arguments["file"]):
        self.assertEqual(expected, actual)

    result = await tool(file=file_ids)
    self.assertEqual(len(result), 0)


class TestEnumSchema(unittest.IsolatedAsyncioTestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit c0b04d8

Please sign in to comment.