diff --git a/doc/jsk_perception/nodes/classification_node.md b/doc/jsk_perception/nodes/classification_node.md
index 3262618c48..f038526fb1 100644
--- a/doc/jsk_perception/nodes/classification_node.md
+++ b/doc/jsk_perception/nodes/classification_node.md
@@ -2,7 +2,7 @@
![](images/clip.png)
-The ROS node for Classification with CLIP.
+The ROS node for Classification with CLIP or ImageBind.
## System Configuration
![](images/large_scale_vil_system.png)
@@ -65,19 +65,29 @@ make
You can send multiple queries by separating them with semicolons.
### Run inference container on another host or another terminal
+You can now use either CLIP or ImageBind.
+
+#### If you want to use CLIP
On the remote GPU machine,
```shell
cd jsk_recognition/jsk_perception/docker
./run_jsk_vil_api clip --port (Your vacant port)
```
+#### If you want to use ImageBind
+On the remote GPU machine,
+```shell
+cd jsk_recognition/jsk_perception/docker
+./run_jsk_vil_api image-bind --port (Your vacant port)
+```
+
On the ROS machine,
```shell
-roslaunch jsk_perception classification.launch port:=(Your inference container port) host:=(Your inference container host) CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) gui:=true
+roslaunch jsk_perception classification.launch port:=(Your inference container port) host:=(Your inference container host) CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) model:=(Your model name) gui:=true
```
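+
+For example, assuming the ImageBind container is running on a host named `gpu-server` at port `8888` and the camera topic is `/camera/rgb/image_rect_color` (replace these placeholders with your own values):
+```shell
+roslaunch jsk_perception classification.launch port:=8888 host:=gpu-server CLASSIFICATION_INPUT_IMAGE:=/camera/rgb/image_rect_color model:=image-bind gui:=true
+```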
### Run both inference container and ros node in single host
```
-roslaunch jsk_perception classification.launch run_api:=true CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) gui:=true
+roslaunch jsk_perception classification.launch run_api:=true CLASSIFICATION_INPUT_IMAGE:=(Your image topic name) model:=(Your model name) gui:=true
```
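+
+For example, to run the CLIP container and the ROS node on the same machine (the image topic is a placeholder):
+```
+roslaunch jsk_perception classification.launch run_api:=true CLASSIFICATION_INPUT_IMAGE:=/camera/rgb/image_rect_color model:=clip gui:=true
+```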
diff --git a/jsk_perception/docker/Makefile b/jsk_perception/docker/Makefile
index 83eb2800ab..b070b74639 100644
--- a/jsk_perception/docker/Makefile
+++ b/jsk_perception/docker/Makefile
@@ -5,9 +5,11 @@
# api directories
OFAPROJECT = ofa
CLIPPROJECT = clip
+IMAGEBINDPROJECT = image-bind
# image names
OFAIMAGE = jsk-ofa-server
CLIPIMAGE = jsk-clip-server
+IMAGEBINDIMAGE = jsk-image-bind-server
# commands
BUILDIMAGE = docker build
REMOVEIMAGE = docker rmi
@@ -23,7 +25,7 @@ PARAMURLS = parameter_urls.txt
# OFA parameters
OFAPARAMFILES = $(foreach param, $(OFAPARAMS), $(PARAMDIR)/$(param))
-all: ofa clip
+all: ofa clip image-bind
# TODO check command wget exists, nvidia-driver version
@@ -41,6 +43,9 @@ ofa: $(PARAMDIR)/.download
clip: $(PARAMDIR)/.download
$(BUILDIMAGE) $(CLIPPROJECT) -t $(CLIPIMAGE) -f $(CLIPPROJECT)/Dockerfile
+image-bind: $(PARAMDIR)/.download
+ $(BUILDIMAGE) $(IMAGEBINDPROJECT) -t $(IMAGEBINDIMAGE) -f $(IMAGEBINDPROJECT)/Dockerfile
+
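+# `make image-bind` rebuilds only the ImageBind image (jsk-image-bind-server), which run_jsk_vil_api launches.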
# TODO add clip, glip
clean:
@$(REMOVEIMAGE) $(OFAIMAGE)
@@ -48,4 +53,4 @@ clean:
wipe: clean
rm -fr $(PARAMDIR)
-.PHONY: clean wipe ofa clip
+.PHONY: clean wipe ofa clip image-bind
diff --git a/jsk_perception/docker/image-bind/Dockerfile b/jsk_perception/docker/image-bind/Dockerfile
new file mode 100644
index 0000000000..96a0d0e094
--- /dev/null
+++ b/jsk_perception/docker/image-bind/Dockerfile
@@ -0,0 +1,32 @@
+FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-devel
+ARG DEBIAN_FRONTEND=noninteractive
+RUN apt -o Acquire::AllowInsecureRepositories=true update \
+ && apt-get install -y \
+ curl \
+ git \
+ libopencv-dev \
+ wget \
+ emacs \
+ python3.8 \
+ python3-dev \
+ libproj-dev \
+ proj-data \
+ proj-bin \
+ libgeos-dev \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+ENV CUDA_HOME /usr/local/cuda
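+# build CUDA extensions for compute capability 8.0 (A100-class GPUs); +PTX keeps forward compatibility with newer architectures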
+ENV TORCH_CUDA_ARCH_LIST 8.0+PTX
+RUN git clone https://github.com/Kanazawanaoaki/ImageBind.git -b update-data-load
+RUN echo 'export CUDA_HOME=/usr/local/cuda' >> ~/.bashrc
+RUN echo 'export TORCH_CUDA_ARCH_LIST=8.0+PTX' >> ~/.bashrc
+RUN pip install flask opencv-python \
+ && pip install soundfile \
+ && pip install --upgrade pip setuptools wheel \
+ && pip install cartopy==0.19.0.post1
+RUN cd ImageBind \
+ && git pull origin update-data-load \
+ && pip install -r requirements.txt \
+ && pip install -e .
+COPY server.py /workspace/ImageBind
+ENTRYPOINT cd /workspace/ImageBind && python server.py
\ No newline at end of file
diff --git a/jsk_perception/docker/image-bind/server.py b/jsk_perception/docker/image-bind/server.py
new file mode 100644
index 0000000000..03b36f915f
--- /dev/null
+++ b/jsk_perception/docker/image-bind/server.py
@@ -0,0 +1,166 @@
+from imagebind import data
+from imagebind.models import imagebind_model
+from imagebind.models.imagebind_model import ModalityType
+from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
+import io
+
+import cv2
+import numpy as np
+from PIL import Image as PLImage
+import torch
+
+# web server
+from flask import Flask, request, Response
+import json
+import base64
+
+def apply_half(t):
+ if t.dtype is torch.float32:
+ return t.to(dtype=torch.half)
+ return t
+
+class Inference:
+ def __init__(self, modal, gpu_id=None):
+ self.gpu_id = gpu_id
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ self.modal_name = modal
+
+ self.model = imagebind_model.imagebind_huge(pretrained=True)
+ self.model.eval()
+ self.model.to(self.device)
+
+ self.video_sample_rate=16000
+
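+ # join a list of queries into a single " . "-separated prompt string (not called by the endpoints below)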
+ def convert_to_string(self, input_list):
+ output_string = ""
+ for item in input_list:
+ output_string += item + " . "
+ return output_string.strip()
+
+ def infer(self, msg, texts):
+ text_inputs = texts
+
+ if self.modal_name == "image":
+ # get cv2 image
+ # image = cv2.resize(img, dsize=(640, 480)) # NOTE forcibly
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = cv2.cvtColor(msg, cv2.COLOR_BGR2RGB)
+ image = PLImage.fromarray(image)
+
+ image_input = [image]
+
+ inputs = {
+ ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
+ ModalityType.VISION: data.load_and_transform_vision_data(None, self.device, image_input),
+ }
+ modal_data_type = ModalityType.VISION
+
+ elif self.modal_name == "video":
+ import decord
+ decord.bridge.set_bridge("torch")
+ video_io = io.BytesIO(msg)
+ video = EncodedVideoDecord(file=video_io,
+ video_name="current_video_data",
+ decode_video=True,
+ decode_audio=False,
+ **{"sample_rate": self.video_sample_rate},
+ )
+
+ inputs = {
+ ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
+ ModalityType.VISION: data.load_and_transform_video_data(None, self.device, videos=[video]),
+ }
+ modal_data_type = ModalityType.VISION
+
+ elif self.modal_name == "audio":
+ waveform = msg["waveform"]
+ sr = msg["sr"]
+ waveform_np = np.frombuffer(waveform, dtype=np.float32)
+ waveform_torch = torch.tensor(waveform_np.reshape(1, -1))
+
+ inputs = {
+ ModalityType.TEXT: data.load_and_transform_text(text_inputs, self.device),
+ ModalityType.AUDIO: data.load_and_transform_audio_data(None, self.device, audios=[{"waveform": waveform_torch, "sr": sr}]),
+ }
+ modal_data_type = ModalityType.AUDIO
+
+ # Calculate features
+ with torch.no_grad():
+ embeddings = self.model(inputs)
+
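+ # raw modality-to-text dot products (averaged over the batch) and their softmax over the queries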
+ similarity = np.average((embeddings[modal_data_type] @ embeddings[ModalityType.TEXT].T).tolist(), axis=0)
+ probability = torch.softmax(embeddings[modal_data_type] @ embeddings[ModalityType.TEXT].T, dim=-1)
+
+ values, indices = probability[0].topk(len(texts))
+ results = {}
+ for value, index in zip(values, indices):
+ results[texts[index]] = (value.item(), float(similarity[index]))
+ return results
+
+# run
+if __name__ == "__main__":
+ app = Flask(__name__)
+
+ image_infer = Inference("image")
+ video_infer = Inference("video")
+ audio_infer = Inference("audio")
+
+ try:
+ @app.route("/inference", methods=['POST'])
+ def image_request():
+ data = request.data.decode("utf-8")
+ data_json = json.loads(data)
+ # process image
+ image_b = data_json['image']
+ image_dec = base64.b64decode(image_b)
+ data_np = np.frombuffer(image_dec, dtype='uint8')
+ img = cv2.imdecode(data_np, 1)
+ # get text
+ texts = data_json['queries']
+ infer_results = image_infer.infer(img, texts)
+ results = []
+ for q in infer_results:
+ results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
+ return Response(response=json.dumps({"results": results}), status=200)
+ except NameError:
+ print("Skipping create inference app")
+
+ try:
+ @app.route("/video_class", methods=['POST'])
+ def video_request():
+ data = request.data.decode("utf-8")
+ data_json = json.loads(data)
+ # process video
+ video_b = data_json['video']
+ video_dec = base64.b64decode(video_b)
+ # get text
+ texts = data_json['queries']
+ infer_results = video_infer.infer(video_dec, texts)
+ results = []
+ for q in infer_results:
+ results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
+ return Response(response=json.dumps({"results": results}), status=200)
+ except NameError:
+ print("Skipping create video_class app")
+
+ try:
+ @app.route("/audio_class", methods=['POST'])
+ def audio_request():
+ data = request.data.decode("utf-8")
+ data_json = json.loads(data)
+ # process audio
+ audio_b = data_json['audio']
+ sr = data_json['sr']
+ audio_dec = base64.b64decode(audio_b)
+ # get text
+ texts = data_json['queries']
+ infer_results = audio_infer.infer({"waveform": audio_dec, "sr": sr}, texts)
+ results = []
+ for q in infer_results:
+ results.append({"question": q, "probability": infer_results[q][0], "similarity": infer_results[q][1]})
+ return Response(response=json.dumps({"results": results}), status=200)
+ except NameError:
+ print("Skipping create audio_class app")
+
+ app.run("0.0.0.0", 8080, threaded=True)
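+
+# A minimal client sketch for the /inference endpoint defined above; host, port,
+# file name and queries are placeholders. /video_class and /audio_class work the
+# same way but take base64-encoded "video" / "audio" fields ("audio" also needs "sr").
+#
+#   import base64, json, requests
+#   with open("sample.jpg", "rb") as f:
+#       payload = {"image": base64.b64encode(f.read()).decode(),
+#                  "queries": ["a person", "a robot"]}
+#   res = requests.post("http://localhost:8080/inference", data=json.dumps(payload))
+#   print(res.json()["results"])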
diff --git a/jsk_perception/docker/run_jsk_vil_api b/jsk_perception/docker/run_jsk_vil_api
index acef636280..4951899850 100755
--- a/jsk_perception/docker/run_jsk_vil_api
+++ b/jsk_perception/docker/run_jsk_vil_api
@@ -10,7 +10,8 @@ import subprocess
import sys
CONTAINERS = {"ofa": "jsk-ofa-server",
- "clip": "jsk-clip-server"}
+ "clip": "jsk-clip-server",
+ "image-bind": "jsk-image-bind-server"}
OFA_MODEL_SCALES = ["base", "large", "huge"]
parser = argparse.ArgumentParser(description="JSK Vision and Language API runner")
diff --git a/jsk_perception/launch/classification.launch b/jsk_perception/launch/classification.launch
index fbf14fae0e..3403c87249 100644
--- a/jsk_perception/launch/classification.launch
+++ b/jsk_perception/launch/classification.launch
@@ -4,17 +4,19 @@
+
+ args="$(arg model) -p $(arg port)" if="$(arg run_api)" />
host: $(arg host)
port: $(arg port)
+ model: $(arg model)
image_transport: $(arg image_transport)
diff --git a/jsk_perception/src/jsk_perception/vil_inference_client.py b/jsk_perception/src/jsk_perception/vil_inference_client.py
index 02026af4ff..3f04f175b0 100644
--- a/jsk_perception/src/jsk_perception/vil_inference_client.py
+++ b/jsk_perception/src/jsk_perception/vil_inference_client.py
@@ -152,6 +152,7 @@ def __init__(self):
ClassificationTaskFeedback,
ClassificationTaskResult,
"inference")
+ self.model_name = rospy.get_param("~model", default="clip")
def topic_cb(self, data):
if not self.config: rospy.logwarn("No queries"); return
@@ -199,7 +200,7 @@ def inference(self, img_msg, queries):
msg.label_names = labels
msg.label_proba = similarities # cosine similarities
msg.probabilities = probabilities # sum(probabilities) is 1
- msg.classifier = 'clip'
+ msg.classifier = self.model_name
msg.target_names = queries
return msg