{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import open3d as o3d\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# only needed for tutorial, monkey patches visualization\n",
    "sys.path.append('..')\n",
    "import open3d_tutorial\n",
    "# change to True if you want to interact with the visualization windows\n",
    "open3d_tutorial.interactive = not \"CI\" in os.environ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ray Casting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `RaycastingScene` class in Open3D provides basic ray casting functionality.\n",
    "In this tutorial we show how to create a scene and do ray intersection tests. \n",
    "You can also use `RaycastingScene` to create a virtual point cloud from a mesh, \n",
    "such as from a CAD model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Initialization**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the first step we initialize a `RaycastingScene` with one or more triangle meshes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load mesh and convert to open3d.t.geometry.TriangleMesh\n",
    "cube = o3d.geometry.TriangleMesh.create_box().translate([0, 0, 0])\n",
    "cube = o3d.t.geometry.TriangleMesh.from_legacy(cube)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a scene and add the triangle mesh\n",
    "scene = o3d.t.geometry.RaycastingScene()\n",
    "cube_id = scene.add_triangles(cube)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`add_triangles()` returns the ID for the added geometry.\n",
    "This ID can be used to identify which mesh is hit by a ray."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(cube_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Casting rays**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now generate rays which are 6D vectors with origin and direction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We create two rays:\n",
    "# The first ray starts at (0.5,0.5,10) and has direction (0,0,-1).\n",
    "# The second ray start at (-1,-1,-1) and has direction (0,0,-1).\n",
    "rays = o3d.core.Tensor([[0.5, 0.5, 10, 0, 0, -1], [-1, -1, -1, 0, 0, -1]],\n",
    "                       dtype=o3d.core.Dtype.Float32)\n",
    "\n",
    "ans = scene.cast_rays(rays)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result contains information about a possible intersection with the geometry in the scene."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ans.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **t_hit** is the distance to the intersection. The unit is defined by the length of the ray direction. If there is no intersection this is *inf*\n",
    "- **geometry_ids** gives the id of the geometry hit by the ray. If no geometry was hit this is `RaycastingScene.INVALID_ID`\n",
    "- **primitive_ids** is the triangle index of the triangle that was hit or `RaycastingScene.INVALID_ID`\n",
    "- **primitive_uvs** is the barycentric coordinates of the intersection point within the triangle.\n",
    "- **primitive_normals** is the normal of the hit triangle."
   ]
  },
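  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of how the unit of **t_hit** follows the length of the ray direction, the cell below casts the first ray twice, the second time with its direction scaled by 2. The names `scaled_rays` and `scaled_ans` are introduced only for this illustration. Assuming directions are not normalized internally, the second **t_hit** should come out as half of the first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same origin for both rays, but the second direction vector is twice as long.\n",
    "scaled_rays = o3d.core.Tensor([[0.5, 0.5, 10, 0, 0, -1], [0.5, 0.5, 10, 0, 0, -2]],\n",
    "                              dtype=o3d.core.Dtype.Float32)\n",
    "scaled_ans = scene.cast_rays(scaled_rays)\n",
    "# t_hit is reported in multiples of the direction length.\n",
    "print(scaled_ans['t_hit'].numpy())"
   ]
  },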
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see from **t_hit** and **geometry_ids** that the first ray did hit the mesh but the second ray missed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(ans['t_hit'].numpy(), ans['geometry_ids'].numpy())"
   ]
  },
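  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same check can be expressed as boolean masks, which is handy when there are many rays. This is only a sketch: `hit_mask` and `hit_cube` are illustrative names, and the test for a hit relies on **t_hit** being *inf* for rays that miss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rays with a finite hit distance intersected some geometry in the scene.\n",
    "hit_mask = np.isfinite(ans['t_hit'].numpy())\n",
    "# Rays whose geometry ID equals cube_id hit the cube specifically.\n",
    "hit_cube = ans['geometry_ids'].numpy() == cube_id\n",
    "print(hit_mask, hit_cube)"
   ]
  },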
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Creating images**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now create a scene with multiple objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create meshes and convert to open3d.t.geometry.TriangleMesh\n",
    "cube = o3d.geometry.TriangleMesh.create_box().translate([0, 0, 0])\n",
    "cube = o3d.t.geometry.TriangleMesh.from_legacy(cube)\n",
    "torus = o3d.geometry.TriangleMesh.create_torus().translate([0, 0, 2])\n",
    "torus = o3d.t.geometry.TriangleMesh.from_legacy(torus)\n",
    "sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.5).translate(\n",
    "    [1, 2, 3])\n",
    "sphere = o3d.t.geometry.TriangleMesh.from_legacy(sphere)\n",
    "\n",
    "scene = o3d.t.geometry.RaycastingScene()\n",
    "scene.add_triangles(cube)\n",
    "scene.add_triangles(torus)\n",
    "_ = scene.add_triangles(sphere)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`RaycastingScene` allows to organize rays with an arbitrary number of leading dimensions.\n",
    "For instance we can generate an array with shape `[h,w,6]` to organize rays for creating an image.\n",
    "The class also provides helper functions for creating rays for a pinhole camera.\n",
    "The following creates rays Tensor with shape `[480,640,6]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rays = o3d.t.geometry.RaycastingScene.create_rays_pinhole(\n",
    "    fov_deg=90,\n",
    "    center=[0, 0, 2],\n",
    "    eye=[2, 3, 0],\n",
    "    up=[0, 1, 0],\n",
    "    width_px=640,\n",
    "    height_px=480,\n",
    ")\n",
    "# We can directly pass the rays tensor to the cast_rays function.\n",
    "ans = scene.cast_rays(rays)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output tensors preserve the shape of the rays and we can directly visualize the hit distance with matplotlib to get a depth map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.imshow(ans['t_hit'].numpy())"
   ]
  },
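  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rays do not have to come from the pinhole helper. As a rough sketch, the cell below builds a small `[4,5,6]` grid of parallel rays by hand with NumPy and casts it against the same scene. The grid extent, the start height of 10, and the names `ortho`, `ortho_rays` and `ortho_ans` are arbitrary choices made for this illustration. The result should keep the leading `[4,5]` shape of the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid of ray origins in the plane z=10 (4 rows, 5 columns).\n",
    "xs, ys = np.meshgrid(np.linspace(-1, 2, 5), np.linspace(-1, 3, 4))\n",
    "ortho = np.zeros((4, 5, 6), dtype=np.float32)\n",
    "ortho[..., 0] = xs\n",
    "ortho[..., 1] = ys\n",
    "ortho[..., 2] = 10.0\n",
    "# All rays share the direction (0,0,-1).\n",
    "ortho[..., 5] = -1.0\n",
    "\n",
    "ortho_rays = o3d.core.Tensor(ortho)\n",
    "ortho_ans = scene.cast_rays(ortho_rays)\n",
    "# The leading dimensions of the input are preserved in the outputs.\n",
    "print(ortho_ans['t_hit'].shape)"
   ]
  },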
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Further we can plot the other results to visualize the primitive normals, .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use abs to avoid negative values\n",
    "plt.imshow(np.abs(ans['primitive_normals'].numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ".. or the geometry IDs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(ans['geometry_ids'].numpy(), vmax=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a virtual point cloud\n",
    "We can also use the hit distance to calculate the XYZ coordinates of the intersection points. These are the points that you would get by placing a virtual 3D sensor at the point of origin of the rays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hit = ans['t_hit'].isfinite()\n",
    "points = rays[hit][:,:3] + rays[hit][:,3:]*ans['t_hit'][hit].reshape((-1,1))\n",
    "pcd = o3d.t.geometry.PointCloud(points)\n",
    "# Press Ctrl/Cmd-C in the visualization window to copy the current viewpoint\n",
    "o3d.visualization.draw_geometries([pcd.to_legacy()], \n",
    "                                  front=[0.5, 0.86, 0.125],\n",
    "                                  lookat=[0.23, 0.5, 2],\n",
    "                                  up=[-0.63, 0.45, -0.63],\n",
    "                                  zoom=0.7)\n",
    "# o3d.visualization.draw([pcd]) # new API"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
